diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index 0270e31ebe44d7168eb5916db5d6bdd3d3c3f144..c2d1474becb41a5d4a5427a22689d3987ce78fe3 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -509,6 +509,8 @@ def _add_data_args(parser):
                        help="Convert input-layernorm to fp32")
     group.add_argument("--no-shuffle", action='store_true',
                        help="Disable data shuffling, mainly for loss comparison.")
+    group.add_argument('--enable-share-memory', action='store_true', default=False,
+                       help='Enable shared memory for passing actual_seq_len when reset-position-ids is enabled.')
     group.add_argument('--neat-pack', action='store_true',
                        help='Use a zigzag attention mask.')
     group.add_argument('--padded-samples', action='store_true',
@@ -517,6 +519,13 @@ def _add_data_args(parser):
     return parser


+def _validate_share_memory_args(args):
+    if args.enable_share_memory and not args.reset_position_ids:
+        raise AssertionError('Shared memory requires --reset-position-ids to be enabled.')
+    if args.enable_share_memory and args.position_embedding_type == 'alibi':
+        raise AssertionError('Shared memory is not supported with alibi position embeddings.')
+
+
 def _add_num_layer_allocation(parser):
     group = parser.add_argument_group(title='num_layer_allocation')
     group.add_argument('--num-layer-list',
@@ -1013,7 +1022,7 @@ def _add_dataset_args(parser):
         default=[],
         help='Additional keys need to be add from dataset.'
     )
-    
+
     return parser


@@ -1080,7 +1089,7 @@ def _validate_recompute_args(args):
             raise AssertionError('uniform recomputation is not compatible with activation function recomputation.')
         if args.recompute_granularity == "selective":
            raise AssertionError('--recompute-activation-function is not compatible with selective recomputation.')
-    
+
     if args.recompute_norm:
         if args.recompute_method == "uniform":
             raise AssertionError('uniform recomputation is not compatible with norm recomputation.')
@@ -1088,7 +1097,7 @@ def _validate_recompute_args(args):
             raise AssertionError('--recompute-norm is not compatible with selective recomputation')
         if not args.use_mcore_models:
             raise AssertionError('--recompute-norm is only supported with mcore models')
-    
+
     if args.swap_attention and args.swap_modules is None:
         if args.use_mcore_models:
             args.swap_modules = "input_layernorm,self_attention,pre_cross_attn_layernorm"
@@ -1581,6 +1590,7 @@ def validate_args_decorator(megatron_validate_args):
         _validate_create_attention_mask_in_dataloader(args)
         _validate_instruction_finetune(args)
         _validate_position_embedding(args)
+        _validate_share_memory_args(args)
         _validate_high_availability(args)
         _validate_inference_args(args)
         _validate_moe_args(args)
@@ -1611,7 +1621,7 @@ def validate_args_decorator(megatron_validate_args):
             feature.pre_validate_args(args)
             feature.validate_args(args)
             feature.post_validate_args(args)
-    
+
         from mindspeed_llm.training.utils import print_args
         print_args('MindSpeed-LLM Arguments', args)
         return args
diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py
index 47114829d68ee6109cf92e572aa2b7fb99a2aabc..31ffec60a958fb5bdde6fa6e618d1b1760d44ccc 100644
--- a/mindspeed_llm/training/utils.py
+++ b/mindspeed_llm/training/utils.py
@@ -16,10 +16,13 @@
 """General utilities."""
 import os
 import stat
+import time
+import atexit
 import random
 import warnings
 from functools import wraps
 from typing import Optional, Union, List
+from multiprocessing.shared_memory import SharedMemory
 from itertools import takewhile

 import torch
@@ -46,7 +49,7 @@
 WRITE_FILE_DEFAULT_FLAGS = os.O_WRONLY | os.O_CREAT
 WRITE_FILE_DEFAULT_MODES = stat.S_IWUSR | stat.S_IRUSR

-def compute_actual_seq_len(origin_seq):
+def _compute_actual_seq_len(origin_seq):
     seq = origin_seq.view(-1)
     zero_pos = (seq == 0).nonzero()[1:].squeeze(dim=1)
     res = zero_pos.tolist()
@@ -54,14 +57,11 @@ def compute_actual_seq_len(origin_seq):
     return res


-def generate_actual_seq_len(batch):
+def compute_actual_seq_len(origin_seq):
     args = get_args()
-    position_ids = batch.get('position_ids').transpose(0, 1).contiguous()
-    set_position_ids(position_ids)
-    position_ids = batch.get('position_ids')
-    actual_seq_len = compute_actual_seq_len(position_ids)
+    actual_seq_len = _compute_actual_seq_len(origin_seq)
     if args.mtp_num_layers:
-        seq_len = position_ids.shape[1]
+        seq_len = origin_seq.shape[1]
         mtp_res = [actual_seq_len]
         for i in range(1, args.mtp_num_layers + 1):
             next_actual_seq_len = []
@@ -71,8 +71,18 @@ def generate_actual_seq_len(batch):
             else:
                 next_actual_seq_len.append(j - i)
             mtp_res.append(next_actual_seq_len)
-        set_actual_seq_len(mtp_res)
+        return mtp_res
+    return actual_seq_len
+
+
+def generate_actual_seq_len(batch, actual_seq_len=None):
+    position_ids = batch.get('position_ids').transpose(0, 1).contiguous()
+    set_position_ids(position_ids)
+    if actual_seq_len is not None:
+        set_actual_seq_len(actual_seq_len)
     else:
+        position_ids = batch.get('position_ids')
+        actual_seq_len = compute_actual_seq_len(position_ids)
         set_actual_seq_len(actual_seq_len)
@@ -256,8 +266,362 @@ def get_finetune_data_on_this_tp_rank(data_iterator):
     return tokens, attention_mask


+_GLOBAL_SHM_MANAGER = None  # Global shared memory manager instance
+_SHM_SKIP_FLAG = False  # Whether shared memory should be skipped
+BASE_SHM_NAME = "g_shm"
+
+
+def reset_sharedmem_mgr():
+    """
+    Reset the shared memory manager and its state flags.
+    """
+    global _GLOBAL_SHM_MANAGER, _SHM_SKIP_FLAG
+
+    if _GLOBAL_SHM_MANAGER is not None:
+        try:
+            _GLOBAL_SHM_MANAGER.close()
+        except Exception as e:
+            print(f"[SharedMemoryManager] [WARN] Error during SharedMemoryManager shutdown: {e}")
+
+    _GLOBAL_SHM_MANAGER = None
+    _SHM_SKIP_FLAG = False
+
+
+def get_sharedmem_mgr(base_shm_name="g_shm", buffer_length=4096):
+    """
+    Get the global shared memory manager used to pass data between ranks via shared memory.
+    :param base_shm_name: base name of the shared memory segment
+    :param buffer_length: shared memory buffer size (in int64 elements), default: 4K
+    :return: `SharedMemoryManager`
+    """
+    global _GLOBAL_SHM_MANAGER, _SHM_SKIP_FLAG
+
+    if _SHM_SKIP_FLAG:
+        return None
+
+    if _GLOBAL_SHM_MANAGER is not None:
+        return _GLOBAL_SHM_MANAGER
+
+    rank = mpu.get_tensor_model_parallel_rank()
+    global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1
+
+    if not torch.distributed.is_initialized():
+        print(
+            f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}]"
+            f"[Func: get_sharedmem_mgr] "
+            f"torch.distributed not initialized, skipping..."
+        )
+        return None
+
+    args = get_args()
+    reset_position_ids = args.reset_position_ids
+    enable_shm = args.enable_share_memory
+    tp_size = mpu.get_tensor_model_parallel_world_size()
+    device_count = torch.cuda.device_count()
+
+    if not (reset_position_ids and enable_shm and tp_size > 1 and tp_size <= device_count):
+        print(
+            f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}]"
+            f"[Func: get_sharedmem_mgr] Skip creation. "
" + f"reset_position_ids={reset_position_ids}, enable_shm={enable_shm}, " + f"tp_size={tp_size}, device_count={device_count}" + ) + _SHM_SKIP_FLAG = True + return None + + if rank == 0: + pid = os.getpid() + _GLOBAL_SHM_MANAGER = SharedMemoryManager( + base_shm_name, rank0_pid=pid, buffer_length=buffer_length, tp_size=tp_size + ) + print( + f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}] Created: " + f"{_GLOBAL_SHM_MANAGER.shm_name}, TP_size: {tp_size}, TP_Group: {_GLOBAL_SHM_MANAGER.tp_group_id}" + ) + + try: + torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group()) + except RuntimeError as e: + print( + f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}]" + f"[Func: get_sharedmem_mgr] Barrier timeout: {e}" + ) + + if rank == 0: + pid = os.getpid() + pid_tensor = torch.tensor([pid], dtype=torch.int32, device="cuda") + torch.distributed.broadcast(pid_tensor, mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group()) + else: + pid_tensor = torch.zeros(1, dtype=torch.int32, device="cuda") + torch.distributed.broadcast(pid_tensor, mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group()) + pid = pid_tensor.item() + _GLOBAL_SHM_MANAGER = SharedMemoryManager( + base_shm_name, rank0_pid=pid, buffer_length=buffer_length, tp_size=tp_size, existing=True + ) + print( + f"[SharedMemoryManager][Rank {rank}][global_rank {global_rank}] Connected to: " + f"{_GLOBAL_SHM_MANAGER.shm_name}, TP_size: {tp_size}, TP_Group: {_GLOBAL_SHM_MANAGER.tp_group_id}" + ) + + torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group()) + return _GLOBAL_SHM_MANAGER + + +class SharedMemoryManager: + def __init__(self, base_shm_name, rank0_pid, buffer_length, tp_size, + existing=False, timeout=3600.0, sleep_time=0.001): + """ + :param base_shm_name: 共享内存的基础名称(多个 TP 组时,每个组有独立共享内存) + :param buffer_length: 共享内存 buffer 大小(以 int64 计) + :param tp_size: TP 组大小(每个 TP 组的进程数) + :param existing: 是否连接已存在的共享内存 + :param timeout: 读写操作的超时时间(秒) + :param sleep_time: 读写等待时的 sleep 时间(秒) + """ + self.buffer_length = buffer_length + self.tp_size = tp_size + self.timeout = timeout + self.sleep_time = sleep_time + self.int64_size = torch.tensor(0, dtype=torch.int64).element_size() + + self.rank = mpu.get_tensor_model_parallel_rank() + self.global_rank = torch.distributed.get_rank() + self.tp_group_id = self.global_rank // self.tp_size + + if rank0_pid is None: + raise ValueError("SharedMemoryManager requires rank0_pid to construct shm_name.") + self.shm_name = self.generate_shm_name(base_shm_name, rank0_pid, self.tp_group_id) + + print(f"[SharedMemoryManager][Rank {self.rank}] Using shm_name: {self.shm_name}") + + self.total_size = ( + buffer_length * self.int64_size + (tp_size + 3) * self.int64_size + ) + + if not existing: + try: + existing_shm = SharedMemory(name=self.shm_name) + existing_shm.close() + existing_shm.unlink() + import multiprocessing.resource_tracker as rt + rt.unregister(self.shm_name, "shared_memory") + print(f"[SharedMemoryManager][Rank {self.rank}] Unlinked residual shared memory '{self.shm_name}'.") + except FileNotFoundError: + pass + except Exception as e: + print(f"[SharedMemoryManager][Rank {self.rank}] Failed to unlink residual shared memory: {e}") + + self.shm = SharedMemory( + name=self.shm_name, + create=not existing, + size=self.total_size if not existing else 0 + ) + + offset = 0 + self.tensor = torch.frombuffer( + self.shm.buf[offset:offset + buffer_length * self.int64_size], + dtype=torch.int64 + 
+        ).view((buffer_length,))
+        offset += buffer_length * self.int64_size
+
+        self.seq_len_real_length = torch.frombuffer(
+            self.shm.buf[offset:offset + self.int64_size],
+            dtype=torch.int64
+        )
+        offset += self.int64_size
+
+        self.seq_len_num = torch.frombuffer(
+            self.shm.buf[offset:offset + self.int64_size],
+            dtype=torch.int64
+        )
+        offset += self.int64_size
+
+        self.read_flags = torch.frombuffer(
+            self.shm.buf[offset:offset + tp_size * self.int64_size],
+            dtype=torch.int64
+        )
+        offset += tp_size * self.int64_size
+
+        self.data_version = torch.frombuffer(
+            self.shm.buf[offset:],
+            dtype=torch.int64
+        )
+
+        if not existing:
+            self.read_flags.zero_()
+            self.data_version.zero_()
+            self.seq_len_real_length.zero_()
+            self.seq_len_num.zero_()
+
+        self.local_version = self.data_version.item()
+
+        # Register automatic cleanup of the shared memory at process exit
+        atexit.register(self.close)
+
+    @staticmethod
+    def generate_shm_name(base_name, rank0_pid, tp_group_id):
+        return f"{base_name}_pid{rank0_pid}_tp{tp_group_id}"
+
+    def write(self, data):
+        if self.rank != 0 or self.tp_size == 1:
+            self.read_flags[self.rank] = 1
+            return
+
+        start_time = time.time()
+        last_log_time = start_time
+        while self.data_version.item() > 0 and self.read_flags.sum().item() < self.tp_size:
+            elapsed_time = time.time() - start_time
+            if elapsed_time > self.timeout:
+                print(
+                    f"[SharedMemoryManager][Rank {self.rank}]"
+                    f"[global_rank {self.global_rank}][Func: write] "
+                    f"Timeout: other ranks did not read data in time. "
+                    f"read_flags: {self.read_flags.tolist()}"
+                )
+                self.read_flags[self.rank] = 1
+                return
+
+            if time.time() - last_log_time > 60.0:
+                print(
+                    f"[SharedMemoryManager][Rank {self.rank}]"
+                    f"[global_rank {self.global_rank}][Func: write] Waiting... "
+                    f"Elapsed: {elapsed_time:.2f}s, "
+                    f"read_flags sum = {self.read_flags.sum().item()} / {self.tp_size}"
+                )
+                last_log_time = time.time()
+            time.sleep(self.sleep_time)
+
+        if isinstance(data, list):
+            if isinstance(data[0], torch.Tensor):
+                data = [item.numpy() for item in data]
+            data = torch.tensor(data, dtype=torch.int64)
+
+        real_length = data.numel() if data is not None else 0
+        seq_len_num = data.shape[0] if data is not None and len(data.shape) >= 2 else 0
+        self.read_flags.zero_()
+
+        if data is None or real_length == 0:
+            print(
+                f"[SharedMemoryManager][Rank {self.rank}]"
+                f"[global_rank {self.global_rank}][Func: write] "
+                f"Writing None, setting seq_len_real_length=-1"
+            )
+            self.seq_len_real_length.fill_(-1)
+            self.seq_len_num.fill_(-1)
+        else:
+            self.tensor[:real_length].copy_(data.view(-1)[:real_length])
+            self.tensor[real_length:].fill_(0)
+            self.seq_len_real_length.fill_(real_length)
+            self.seq_len_num.fill_(seq_len_num)
+
+        self.data_version.add_(1)
+        self.read_flags[self.rank] = 1
+
+    def read(self):
+        if self.rank == 0 or self.tp_size == 1:
+            self.read_flags[self.rank] = 1
+            return None
+
+        start_time = time.time()
+        last_log_time = start_time
+        while self.data_version.item() <= self.local_version:
+            elapsed_time = time.time() - start_time
+            if elapsed_time > self.timeout:
+                print(
+                    f"[SharedMemoryManager][Rank {self.rank}]"
+                    f"[global_rank {self.global_rank}][Func: read] Timeout: No new data. "
+                    f"data_version={self.data_version.item()}, "
+                    f"local_version={self.local_version}"
+                )
+                self.read_flags[self.rank] = 1
+                return None
+
+            if time.time() - last_log_time > 60.0:
+                print(
+                    f"[SharedMemoryManager][Rank {self.rank}]"
+                    f"[global_rank {self.global_rank}][Func: read] Still waiting... "
" + f"Elapsed: {elapsed_time:.2f}s, " + f"data_version={self.data_version.item()}, " + f"expected version > {self.local_version}" + ) + last_log_time = time.time() + time.sleep(self.sleep_time) + + real_length = self.seq_len_real_length.item() + seq_len_num = self.seq_len_num.item() + if real_length == -1: + print( + f"[SharedMemoryManager][Rank {self.rank}]" + f"[global_rank {self.global_rank}][Func: read] " + f"Detected None data (real_length=-1)" + ) + data = None + else: + if seq_len_num <= 1: + data = self.tensor[:real_length].clone() + else: + data = self.tensor[:real_length].clone().view(seq_len_num, -1) + + self.local_version = self.data_version.item() + self.read_flags[self.rank] = 1 + + if isinstance(data, torch.Tensor): + data = data.tolist() + return data + + def close(self): + if self.rank == 0: + start_time = time.time() + while self.read_flags.sum().item() < self.tp_size: + if time.time() - start_time > self.timeout: + print( + f"[SharedMemoryManager][Rank {self.rank}]" + f"[global_rank {self.global_rank}][Func: close] " + f"Timeout waiting for ranks to finish reading. " + f"read_flags: {self.read_flags.tolist()}" + ) + break + time.sleep(self.sleep_time) + + del self.tensor + del self.seq_len_real_length + del self.seq_len_num + del self.read_flags + del self.data_version + + import gc + gc.collect() + time.sleep(0.1) + + try: + self.shm.close() + if self.rank == 0: + self.shm.unlink() + print( + f"[SharedMemoryManager][Rank {self.rank}]" + f"[global_rank {self.global_rank}][Func: close] " + f"Shared memory '{self.shm_name}' released and unlinked." + ) + else: + import multiprocessing.resource_tracker as rt + rt.unregister(self.shm._name, "shared_memory") + + except Exception as e: + print( + f"[SharedMemoryManager][Rank {self.rank}]" + f"[global_rank {self.global_rank}][Func: close] " + f"Cleanup error during shm close/unlink: {e}" + ) + def get_batch_on_this_tp_rank(data_iterator): + batch, _ = get_batch_on_this_tp_rank_expand(data_iterator) + return batch + + +def get_batch_on_this_tp_rank_expand(data_iterator): args = get_args() def _broadcast(item): @@ -265,12 +629,22 @@ def get_batch_on_this_tp_rank(data_iterator): torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) + shm_manager = None + actual_seq_len = None + if args.enable_share_memory: + shm_manager = get_sharedmem_mgr(BASE_SHM_NAME, args.micro_batch_size * args.seq_length) + if mpu.get_tensor_model_parallel_rank() == 0: if data_iterator is not None: data = next(data_iterator) else: data = None + if args.enable_share_memory and shm_manager is not None: + position_ids = data["position_ids"] + actual_seq_len = compute_actual_seq_len(position_ids) + shm_manager.write(actual_seq_len) + if args.return_document_ids and mpu.get_context_parallel_rank() == 0 and mpu.get_pipeline_model_parallel_rank() == 0: document_ids = [ [x.item() for x in takewhile(lambda y: y.item() != -100, row)] @@ -333,6 +707,8 @@ def get_batch_on_this_tp_rank(data_iterator): _broadcast(batch['position_ids']) else: + if args.enable_share_memory and shm_manager is not None: + actual_seq_len = shm_manager.read() tokens = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, @@ -404,7 +780,7 @@ def get_batch_on_this_tp_rank(data_iterator): 'position_ids': position_ids } - return batch + return batch, actual_seq_len def get_batch_on_this_cp_rank(batch): diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 
index 196bec682265b5e01d926f8f27aea36f75589f1b..c8aa94fe964f9dca74e2fe721338418eac6158c2 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -33,7 +33,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_local_spec,
     get_gpt_layer_with_transformer_engine_spec,
 )
-from mindspeed_llm.training.utils import generate_actual_seq_len, tensor_slide
+from mindspeed_llm.training.utils import generate_actual_seq_len, tensor_slide, get_batch_on_this_tp_rank_expand


 def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
@@ -104,16 +104,21 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
 def get_batch(data_iterator):
     """Generate a batch."""
-    # get batches based on the TP rank you are on
-    batch = get_batch_on_this_tp_rank(data_iterator)
     args = get_args()
+    actual_seq_len = None
+    # get batches based on the TP rank you are on
+    if args.enable_share_memory:
+        batch, actual_seq_len = get_batch_on_this_tp_rank_expand(data_iterator)
+    else:
+        batch = get_batch_on_this_tp_rank(data_iterator)
+
     if args.return_document_ids and mpu.get_context_parallel_rank() == 0 and mpu.get_tensor_model_parallel_rank() == 0 and mpu.get_pipeline_model_parallel_rank() == 0:
         print("current idx: {}, current rank: {}, data_parallel_rank: {}, document_ids: {}".format(batch['idx'], torch.distributed.get_rank(), mpu.get_data_parallel_rank(), batch['document_ids']))
         batch.pop('document_ids', None)
         batch.pop('idx', None)

     if args.reset_position_ids:
-        generate_actual_seq_len(batch)
+        generate_actual_seq_len(batch, actual_seq_len)

     # slice batch along sequence dimension for context parallelism
     batch = get_batch_on_this_cp_rank(batch)
     return batch.values()
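
Usage sketch (illustrative only, not part of the patch): with --reset-position-ids and --enable-share-memory set, TP rank 0 derives actual_seq_len from its position_ids and publishes it through the per-TP-group shared memory segment, while the other ranks in the group read it instead of recomputing it from broadcast position_ids. The helper name fetch_actual_seq_len below is hypothetical; get_sharedmem_mgr, compute_actual_seq_len, BASE_SHM_NAME, and mpu are the names introduced in the patched mindspeed_llm/training/utils.py.

def fetch_actual_seq_len(data, args):
    # Assumes get_sharedmem_mgr() returned a manager for this TP group
    # (tp_size > 1, shared memory enabled and available).
    shm_manager = get_sharedmem_mgr(BASE_SHM_NAME, args.micro_batch_size * args.seq_length)
    if mpu.get_tensor_model_parallel_rank() == 0:
        # Rank 0 owns the data iterator: derive the lengths and publish them.
        actual_seq_len = compute_actual_seq_len(data["position_ids"])
        shm_manager.write(actual_seq_len)
    else:
        # Other TP ranks block (up to the manager's timeout) until rank 0 writes.
        actual_seq_len = shm_manager.read()
    return actual_seq_len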