diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py
index 650a5db15c1812c6cdf282130b2aa5e3100241a7..2c055601f3e490a94fd05458ae370307e65ed71e 100644
--- a/mindspeed_rl/config_cls/rl_config.py
+++ b/mindspeed_rl/config_cls/rl_config.py
@@ -46,6 +46,8 @@ class RLConfig(BaseConfig):
         wandb_save_dir: Path to save the wandb results locally. (default: "")
         blocking: Whether to enable blocking mode (default: False)
         num_cpus_for_local_task: Number of CPUs for local ray task (default: 1)
+
+        use_remove_padding: Whether to use packed (padding-removed) sequences in the forward pass (default: False)
 
     # Default values can still be defined if no config is provided
     '''
@@ -94,5 +96,7 @@ class RLConfig(BaseConfig):
         self.wandb_save_dir = ""
         self.blocking = False
         self.num_cpus_for_local_task = 1
+
+        self.use_remove_padding = False
 
         self.update(config_dict)
diff --git a/mindspeed_rl/config_cls/validate_config.py b/mindspeed_rl/config_cls/validate_config.py
index bba4ae9657239048e10d203fb1452b657d4989b0..1625b1b68f7d04624820f07b68f9228bf0e921ec 100644
--- a/mindspeed_rl/config_cls/validate_config.py
+++ b/mindspeed_rl/config_cls/validate_config.py
@@ -10,6 +10,14 @@ def validate_rl_args(actor_config, ref_config, reward_config, rl_config, generat
             f"Actor.seq_length={actor_config.seq_length} vs "
             f"GenerateConfig.max_model_len={generate_config.max_model_len}")
 
+    # Validate the configuration required by the remove-padding (packed sequence) feature
+    if (rl_config.use_remove_padding and actor_config.pipeline_model_parallel_size > 1) and not actor_config.variable_seq_lengths:
+        raise ValueError(
+            "The 'use_remove_padding' feature requires 'variable_seq_lengths=True' when pipeline parallelism is used! "
+            "If you want to use context parallelism under this premise and encounter the mindspeed_llm validation error about variable_seq_lengths, "
+            "you can simply delete that validation code in mindspeed_llm; doing so will not cause problems."
+ ) + # 初始化经验计数配置 rl_config.experience_count_actor = rl_config.experience_count_actor or rl_config.experience_count rl_config.experience_count_ref = rl_config.experience_count_ref or rl_config.experience_count diff --git a/mindspeed_rl/models/base/base_training_engine.py b/mindspeed_rl/models/base/base_training_engine.py index e72a1a6dc4bfe66bf83875798c5777dfefd0c57a..e44ab3a4b6f0e7de790b03f78770d885ee226f0d 100644 --- a/mindspeed_rl/models/base/base_training_engine.py +++ b/mindspeed_rl/models/base/base_training_engine.py @@ -10,7 +10,12 @@ from torch.utils.data import DataLoader from mindspeed_rl.models.loss.base_loss_func import BaseLossFunc from mindspeed_rl.models.loss.loss_func_factory import LossFuncFactory from mindspeed_rl.utils.utils import ( - append_to_dict, generate_mask, generate_position_ids, get_tune_attention_mask + append_to_dict, + generate_mask, + generate_position_ids, + get_tune_attention_mask, + preprocess_packed_seqs, + postprocess_packed_seqs ) @@ -49,6 +54,7 @@ class BaseTrainingEngine(ABC): clip_ratio: float = 0.1, role: str = None, micro_batch_size: int = 1, + use_remove_padding: bool = False, forward_backward_func: Callable = None, **kwargs): self.forward_backward_func = forward_backward_func @@ -64,6 +70,7 @@ class BaseTrainingEngine(ABC): self.role = role self.kl_ctrl = kl_ctrl self.clip_ratio = clip_ratio + self.use_remove_padding = use_remove_padding self.loss_func: BaseLossFunc = LossFuncFactory.get_instance(self.stage, self.role) self.kwargs = kwargs @@ -88,9 +95,23 @@ class BaseTrainingEngine(ABC): self.loss_func.add_loss_meta_info(self.get_loss_meta_func()) + from megatron.core import mpu + post_process = mpu.get_pipeline_model_parallel_world_size() == 1 or mpu.is_pipeline_last_stage() + def forward_step(batch_iter, model): - input_ids, attention_mask, position_ids, process_batch = self._get_forward_batch_info(batch_iter) - output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + if self.use_remove_padding: + input_ids, position_ids, process_batch, seqlens_in_batch, cu_seqlens_padded = self._get_forward_batch_info(batch_iter) + output_orig = model(input_ids=input_ids, attention_mask=None, position_ids=position_ids) + if not post_process: + output = output_orig + else: + output = postprocess_packed_seqs(output=output_orig, + seqlens_in_batch=seqlens_in_batch, + cu_seqlens_padded=cu_seqlens_padded, + seq_len=seq_len) + else: + input_ids, attention_mask, position_ids, process_batch = self._get_forward_batch_info(batch_iter) + output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) return output, partial(self.loss_func.compute_loss, batch=process_batch, forward_only=forward_only) # batch should be a list of batches inside micro-batches @@ -99,8 +120,8 @@ class BaseTrainingEngine(ABC): data_iterator=iter(batches), model=self.model, num_microbatches=n_micro_batch, - seq_length=seq_len, - micro_batch_size=self.micro_batch_size, + seq_length=self.micro_batch_size * seq_len if self.use_remove_padding else seq_len, + micro_batch_size=1 if self.use_remove_padding else self.micro_batch_size, forward_only=forward_only, collect_non_loss_data=forward_only, ) @@ -113,15 +134,21 @@ class BaseTrainingEngine(ABC): """ return {} - @staticmethod - def _get_forward_batch_info(batch_iter): + def _get_forward_batch_info(self, batch_iter): batch = next(batch_iter) input_ids = batch['input_ids'] attention_mask_1d = generate_mask(input_ids, batch['prompt_length'] + batch['response_length']).to( 
             input_ids.device)
-        position_ids = torch.tensor(generate_position_ids(input_ids)).to(input_ids.device)
-        attention_mask = get_tune_attention_mask(attention_mask_1d)
-        return input_ids, attention_mask, position_ids, batch
+        if self.use_remove_padding:
+            from megatron.core import parallel_state
+            tp_size = parallel_state.get_tensor_model_parallel_world_size()
+            input_ids, position_ids, seqlens_in_batch, cu_seqlens_padded = preprocess_packed_seqs(
+                input_ids=input_ids, attention_mask_1d=attention_mask_1d, tp_size=tp_size)
+            return input_ids, position_ids, batch, seqlens_in_batch, cu_seqlens_padded
+        else:
+            position_ids = torch.tensor(generate_position_ids(input_ids)).to(input_ids.device)
+            attention_mask = get_tune_attention_mask(attention_mask_1d)
+            return input_ids, attention_mask, position_ids, batch
 
     def post_process_forward_backward_output(self, output: [torch.Tensor],
                                              batch: Dict[str, torch.Tensor]) -> torch.Tensor:
diff --git a/mindspeed_rl/utils/utils.py b/mindspeed_rl/utils/utils.py
index 117b9e25bfbef54656499be40599af6a1442e68f..d759b07ec65db17d76588448b2a0bdcf25791a61 100644
--- a/mindspeed_rl/utils/utils.py
+++ b/mindspeed_rl/utils/utils.py
@@ -6,7 +6,7 @@ import sys
 import time
 import random
 
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import omegaconf
 import numpy as np
@@ -79,6 +79,114 @@ def get_tune_attention_mask(attention_mask_1d, reset_attention_mask=True, tokeni
     return attention_mask
 
 
+def preprocess_packed_seqs(
+        input_ids: torch.Tensor,
+        attention_mask_1d: torch.Tensor,
+        tp_size: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Packs variable-length sequences from a batch into a single contiguous tensor for efficient processing.
+
+    Parameters:
+        input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing token IDs.
+        attention_mask_1d (torch.Tensor): Binary mask tensor of shape (batch_size, seq_len) where
+            each entry indicates valid token positions (1) vs padding (0). dtype should be torch.int or torch.bool.
+        tp_size (int): Alignment factor for packing; sequences are padded so that their lengths
+            are multiples of this size.
+
+    Returns:
+        input_ids_packed (torch.Tensor): Tensor of shape (1, pack_length) with all valid tokens packed sequentially.
+        position_ids_packed (torch.Tensor): Tensor of shape (1, pack_length) containing positional
+            indices within each padded sequence block.
+        seqlens_in_batch (torch.Tensor): 1D int32 tensor of shape (batch_size,) with original
+            sequence lengths (number of valid tokens per sample).
+        cu_seqlens_padded (torch.Tensor): 1D int32 tensor of shape (batch_size+1,) containing
+            cumulative padded sequence lengths, used for indexing into the packed tensor.
+
+    Raises:
+        ValueError: If input_ids and attention_mask_1d have incompatible shapes.
+ """ + batch_size, seq_len = input_ids.shape + if attention_mask_1d.shape != (batch_size, seq_len): + raise ValueError("attention_mask_1d must have shape (batch_size, seq_len) matching input_ids") + + # Compute actual sequence lengths per sample + seqlens_in_batch = attention_mask_1d.sum(dim=1, dtype=torch.int32) + # Compute padding needed to align lengths to tp_size + pad_size = (tp_size - (seqlens_in_batch % tp_size)) % tp_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + + # Cumulative lengths without and with padding + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + + # Total packed length after padding + pack_length = int(seqlens_in_batch_padded.sum().item()) + input_ids_packed = torch.zeros(pack_length, dtype=input_ids.dtype, device=input_ids.device) + # Copy valid tokens sequentially + for i in range(batch_size): + start = cu_seqlens_padded[i].item() + length = seqlens_in_batch[i].item() + input_ids_packed[start:start + length] = input_ids[i, :length] + + # Generate position IDs within each padded segment + position_ids_packed = torch.zeros(pack_length, dtype=torch.int32, device=input_ids.device) + for i in range(batch_size): + start = cu_seqlens_padded[i].item() + end = cu_seqlens_padded[i + 1].item() + position_ids_packed[start:end] = torch.arange( + end - start, dtype=torch.int32, device=input_ids.device + ) + + return ( + input_ids_packed.unsqueeze(0), + position_ids_packed.unsqueeze(0), + seqlens_in_batch, + cu_seqlens_padded + ) + + +def postprocess_packed_seqs( + output: torch.Tensor, + seqlens_in_batch: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + seq_len: int +) -> torch.Tensor: + """ + Unpacks a packed output tensor back into the original batch shape, restoring padding. + + Parameters: + output (torch.Tensor): Packed tensor of shape (1, pack_length, ...), typically the model output. + seqlens_in_batch (torch.Tensor): 1D int32 tensor of original sequence lengths, shape (batch_size,). + cu_seqlens_padded (torch.Tensor): 1D int32 tensor of cumulative padded lengths, shape (batch_size+1,). + batch_size (int): Original batch size. + seq_len (int): Maximum sequence length (including padding) for the output reconstruction. + + Returns: + output_new (torch.Tensor): Tensor of shape (batch_size, seq_len, ...), with original outputs + in the first seqlens_in_batch positions and zeros for padding positions. + + Raises: + ValueError: If seqlens_in_batch length does not match batch_size. 
+ """ + if seqlens_in_batch.shape[0] != batch_size: + raise ValueError("Length of seqlens_in_batch must equal batch_size") + + # Prepare new output with padding + batch_size = seqlens_in_batch.shape[0] + full_shape = [batch_size, seq_len] + list(output.shape[2:]) + output_new = torch.zeros(full_shape, dtype=output.dtype, device=output.device) + + for i in range(batch_size): + start = cu_seqlens_padded[i].item() + length = seqlens_in_batch[i].item() + output_new[i, :length] = output[0, start:start + length] + + return output_new + + def append_to_dict(data: Dict, new_data: Dict): for key, val in new_data.items(): if key not in data: diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index 814a0d02593fef39416a88d6a28e39ecf27055c8..641cd811f37ad3f82133626465456e53d2863234 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -96,7 +96,8 @@ class ActorHybridWorker(BaseWorker): stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, clip_ratio=self.rl_config.clip_ratio, - micro_batch_size=self.megatron_config.micro_batch_size + micro_batch_size=self.megatron_config.micro_batch_size, + use_remove_padding=self.rl_config.use_remove_padding ) self.empty_cache() diff --git a/mindspeed_rl/workers/integrated_worker.py b/mindspeed_rl/workers/integrated_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..970f301a1bcec65b584b594cf2772be00166066f --- /dev/null +++ b/mindspeed_rl/workers/integrated_worker.py @@ -0,0 +1,179 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. + +import dataclasses +from typing import Callable + +import time +import ray +import torch + +from mindspeed_rl.config_cls.megatron_config import MegatronConfig +from mindspeed_rl.config_cls.rl_config import RLConfig +from mindspeed_rl.config_cls.generate_config import GenerateConfig +from mindspeed_rl.utils.tokenizer import BaseTokenizer +from mindspeed_rl.workers.resharding.megatron_sharding_manager import MegatronOffLoader + +from mindspeed_rl.workers.actor_hybrid_worker import ActorHybridWorkerBase +from mindspeed_rl.workers.reference_woker import ReferenceWorkerBase +from mindspeed_rl.workers.reward_woker import RewardWorkerBase +from mindspeed_rl.models.reference import Reference +from mindspeed_rl.models.reward import Reward + + +@ray.remote(resources={"NPU": 0.7}) +class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerBase): + """ + IntegratedWorker class. This class implements the integrated worker for the Actor, Reference and Reward Worker. + + Args: + megatron_config: MegatronConfig Configuration for Megatron-LM (e.g., model parallelism settings). + rl_config: RLConfig Configuration for reinforcement learning (e.g., PPO settings). + generate_config: GenerateConfig Configuration for generation/inference (e.g., vLLM settings). + model_provider: Callable Function to provide the model instance. + initialize_func: Callable Function to initialize the model and environment. + tokenizer: BaseTokenizer = None Object to retrieve the tokenizer. + get_megatron_module: Callable = megatron_module from get_megatron_module. + **kwargs: Additional parameters for base class argument passing. 
+ """ + + def __init__( + self, + megatron_config: MegatronConfig, + rl_config: RLConfig, + generate_config: GenerateConfig, + model_provider: Callable, + initialize_func: Callable, + tokenizer: BaseTokenizer = None, + get_megatron_module: Callable = None, + **kwargs + ): + + # We use Actor as main worker, so only do init for Actor here. + ActorHybridWorkerBase.__init__( + self, + megatron_config, + rl_config, + generate_config, + model_provider=model_provider, + initialize_func=initialize_func, + tokenizer=tokenizer, + get_megatron_module=get_megatron_module, + **kwargs + ) + + self.update_micro_batch_size = rl_config.update_micro_batch_size + + self.reference = None + self.ref_model = None + self.ref_manager = None + + def initialize(self): + + # Based on Actor + ActorHybridWorkerBase.initialize(self) + + # Add Reference + self.ref_model = self.get_model(self.model_provider, self.model_type, wrap_with_ddp=False) + ref_model_load_path = getattr( + self.rl_config.integrated_mode_config, "ref_model_load_path", None + ) if self.rl_config.integrated_mode_config is not None else None + self.load_checkpoint_with_path(self.ref_model, ref_model_load_path, ckpt_only=True) + self.ref_manager = MegatronOffLoader(self.ref_model, wrap_with_ddp=False) + self.ref_manager.offload_param() + self.reference = Reference( + self.ref_model, + beta=self.rl_config.beta, + mini_batch_size=self.rl_config.mini_batch_size, + epochs=self.rl_config.epochs, + shuffle_mini_batch=self.rl_config.shuffle_mini_batch, + generate_config=self.generate_config, + stage=self.megatron_config.stage, + forward_backward_func=self.forward_backward_func, + micro_batch_size=self.megatron_config.micro_batch_size, + use_remove_padding=self.rl_config.use_remove_padding + ) + + def compute_ref_log_prob(self): + start_onload_time = time.time() + self.ref_manager.onload_param() + end_onload_time = time.time() + ray.get( + self.td.update_metrics.remote( + "timing/onload", + value=[round(end_onload_time, 4), round(start_onload_time, 4)], + cumulate=True + ) + ) + + ReferenceWorkerBase.compute_ref_log_prob(self) + + start_offload_time = time.time() + self.ref_manager.offload_param() + end_offload_time = time.time() + ray.get( + self.td.update_metrics.remote( + "timing/offload", + value=[round(end_offload_time, 4), round(start_offload_time, 4)], + cumulate=True + ) + ) + + def update(self, kl_ctrl=None, skip_actor_log_prob=False): + # set update mbs + update_mbs = self.update_micro_batch_size + mbs = self.actor_hybrid.train_actor.micro_batch_size + + args = self.get_args() + + if update_mbs is not None: + self.actor_hybrid.train_actor.micro_batch_size = update_mbs + args.micro_batch_size = update_mbs + + ActorHybridWorkerBase.update(self, kl_ctrl, skip_actor_log_prob) + + args.micro_batch_size = mbs + self.actor_hybrid.train_actor.micro_batch_size = mbs + + def load_checkpoint_with_path(self, model, path, ckpt_only=False): + """Load model checkpoint from a specified path with flexible control. + + Args: + model: The model to load checkpoint into. + path: Path to the checkpoint file/directory. If None, use the path in megatron args. + ckpt_only: If True, only loads model weights (skips optimizer/RNG states). 
+ """ + + # Backup original arguments if needed + original_args = { + 'no_load_optim': getattr(self.get_args(), "no_load_optim", None), + 'no_load_rng': getattr(self.get_args(), "no_load_rng", None), + 'load': getattr(self.get_args(), "load", None), + 'iteration': getattr(self.get_args(), "iteration", None), + 'finetune': getattr(self.get_args(), "finetune", None), + 'consumed_train_samples': getattr(self.get_args(), "consumed_train_samples", None), + 'consumed_valid_samples': getattr(self.get_args(), "consumed_valid_samples", None), + } if ckpt_only or path else {} + + if ckpt_only: + self._set_args({ + "no_load_optim": True, + "no_load_rng": True, + "finetune": True, + 'consumed_train_samples': 0, + 'consumed_valid_samples': 0 + }) + + if path is not None: + self._set_args({"load": path}) + + self.load_checkpoint(model, None, None) + + if original_args: + self._set_args(original_args) + + def _set_args(self, arg_dict): + for key, value in arg_dict.items(): + if hasattr(self.get_args(), key): + setattr(self.get_args(), key, value) + + diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index d813f880491d4894552b87564cef707418f78b91..f63b35e403bfb974ecb010a35fda50e5e386aaec 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -76,7 +76,8 @@ class ReferenceWorker(BaseWorker): generate_config=self.generate_config, stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, - micro_batch_size=self.megatron_config.micro_batch_size + micro_batch_size=self.megatron_config.micro_batch_size, + use_remove_padding=self.rl_config.use_remove_padding ) def init_transfer_dock(self, td): diff --git a/mindspeed_rl/workers/reward_woker.py b/mindspeed_rl/workers/reward_woker.py index 2eb4cd7469eb09e3bd4814861699a338d20174a9..6c381edc598b2bc12d746042765eca5ed8f09184 100644 --- a/mindspeed_rl/workers/reward_woker.py +++ b/mindspeed_rl/workers/reward_woker.py @@ -72,7 +72,8 @@ class RewardWorker(BaseWorker): beta=self.rl_config.beta, stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, - micro_batch_size=self.megatron_config.micro_batch_size + micro_batch_size=self.megatron_config.micro_batch_size, + use_remove_padding=self.rl_config.use_remove_padding ) def init_transfer_dock(self, td):