diff --git a/configs/grpo_trainer_qwen25_32b.yaml b/configs/grpo_trainer_qwen25_32b.yaml
index e92a5ad6d29a278b6c3319a481bace87bb7910da..d9d7890749e54814c4a2c8dec2afc33f894bf326 100644
--- a/configs/grpo_trainer_qwen25_32b.yaml
+++ b/configs/grpo_trainer_qwen25_32b.yaml
@@ -71,6 +71,7 @@ reward_config:
   no_load_rng: true
 
 rl_config:
+  guarantee_order: false
   blocking: false
   gamma: 1.0
   lam: 0.95
diff --git a/configs/grpo_trainer_qwen25_7b.yaml b/configs/grpo_trainer_qwen25_7b.yaml
index 41a262c692fa2936092fbc8e8981911b9defdcac..85f0cb03d2ad6c9f065807370ce63f8ef9c61b11 100644
--- a/configs/grpo_trainer_qwen25_7b.yaml
+++ b/configs/grpo_trainer_qwen25_7b.yaml
@@ -72,6 +72,7 @@ reward_config:
   no_load_rng: true
 
 rl_config:
+  guarantee_order: false
   blocking: false
   gamma: 1.0
   lam: 0.95
diff --git a/docs/algorithms/grpo.md b/docs/algorithms/grpo.md
index 24f2fab4dbbb331a059f76211da88c279145806c..4a06f68f0a2c205cdb759aeb9ae9be1be18fdfad 100644
--- a/docs/algorithms/grpo.md
+++ b/docs/algorithms/grpo.md
@@ -184,6 +184,7 @@ bash examples/r1/qwen25/r1_zero_qwen25_32b_worker.sh
 
 ### `rl_config:`
 * `blocking`: whether to enable asynchronous execution; defaults to False;
+* `guarantee_order`: whether to enable TransferDock order preservation; defaults to False;
 * `n_samples_per_prompt`: number of times each prompt is reused; one prompt input yields n responses;
 * `max_prompt_length`: maximum prompt length in GRPO training; defaults to 512;
 * `clip_ratio`: clip ratio used in the Actor model's training loss; defaults to 0.2, typically within [0.1, 0.3] and at most [0, 1]; larger values allow larger policy updates, and vice versa;
diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py
index de2ae91548ca279af227552aca264dbb0fa7775e..339b06d327f48b4a17afa518202a0796f6bff04d 100644
--- a/mindspeed_rl/config_cls/rl_config.py
+++ b/mindspeed_rl/config_cls/rl_config.py
@@ -83,6 +83,7 @@ class RLConfig(BaseConfig):
         self.wandb_exp_name = ""
         self.wandb_save_dir = ""
         self.blocking = False
+        self.guarantee_order = False
         self.num_cpus_for_local_task = 1
         self.num_cpus_for_placement_group = 8
         self.use_integrated_worker = False
diff --git a/mindspeed_rl/models/rollout/vllm_adapter/vllm_parallel_state.py b/mindspeed_rl/models/rollout/vllm_adapter/vllm_parallel_state.py
index f4c6bc80d028bceec9856086ce125770db1d0306..abcd38be61970ae4c03cd6cb2f74ec8bf029b09b 100644
--- a/mindspeed_rl/models/rollout/vllm_adapter/vllm_parallel_state.py
+++ b/mindspeed_rl/models/rollout/vllm_adapter/vllm_parallel_state.py
@@ -32,6 +32,13 @@ _TP = None
 # Pipeline model parallel group that the current rank belongs to.
 _PP = None
 
+# Tensor model parallel group
+_TP_GROUP_RANKS = None
+
+
+def get_vllm_tp_group_ranks():
+    return _TP_GROUP_RANKS
+
 
 # This method is for initializing the ParallelGroup when using HybridEngine
 def initialize_parallel_state(
@@ -187,9 +194,12 @@ def initialize_model_parallel_for_vllm(
 
     def get_tp_group_ranks():
         if infer_tensor_model_parallel_size > train_tensor_model_parallel_size:
-            return get_split_tp_group_ranks()
+            tp_group_ranks = get_split_tp_group_ranks()
         else:
-            return get_allgather_tp_group_ranks()
+            tp_group_ranks = get_allgather_tp_group_ranks()
+        global _TP_GROUP_RANKS
+        _TP_GROUP_RANKS = tp_group_ranks
+        return tp_group_ranks
 
     _TP = init_model_parallel_group(
diff --git a/mindspeed_rl/trainer/base.py b/mindspeed_rl/trainer/base.py
index 7b59228f9db98622ad0fab6753c2475f21875877..4ad0b2c2f8f96865c9bda9ed6140d5efc589134d 100644
--- a/mindspeed_rl/trainer/base.py
+++ b/mindspeed_rl/trainer/base.py
@@ -37,6 +37,7 @@ class RayBaseTrainer(object):
                  tokenizer: BaseTokenizer = None,
                  dataset_additional_keys: List[str] = None,
                  blocking: bool = False,
+                 guarantee_order: bool = False,
                  num_cpus_for_local_task: float = 0.1,
                  **kwargs):
 
@@ -60,6 +61,7 @@
         self.tokenizer = tokenizer
         self.dataset_additional_keys = dataset_additional_keys
         self.blocking = blocking
+        self.guarantee_order = guarantee_order
         self.num_cpus_for_local_task = num_cpus_for_local_task
         self.kwargs = kwargs
 
diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py
index e361fcd6ce6c7b714a43c0973654e4c8d56f2d37..e9f1d3a9f96601347292ae657aa4e1b57649dcfd 100644
--- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py
+++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py
@@ -61,6 +61,7 @@ class RayGRPOTrainer(RayBaseTrainer):
             tokenizer: BaseTokenizer = None,
             dataset_additional_keys: List[str] = None,
             blocking: bool = False,
+            guarantee_order: bool = False,
             num_cpus_for_local_task: int = 1,
             **kwargs
     ):
@@ -81,6 +82,7 @@
             tokenizer=tokenizer,
             dataset_additional_keys=dataset_additional_keys,
             blocking=blocking,
+            guarantee_order=guarantee_order,
             num_cpus_for_local_task=num_cpus_for_local_task,
             **kwargs
         )
@@ -135,7 +137,7 @@
             self.rule_reward_compute_rm_score(reward_worker, blocking=False)
 
             # compute advantages, executed on the driver process
-            self.compute_advantage(blocking=False)
+            self.compute_advantage(blocking=False, guarantee_order=self.guarantee_order)
 
             # compute reference log_prob
             self.ref_worker.compute_ref_log_prob(blocking=self.blocking)
@@ -156,7 +158,9 @@
             # collect metrics
             grpo_data_metrics = compute_grpo_data_metrics(self.transfer_dock,
                                                           self.global_batch_size * self.n_samples_per_prompt,
-                                                          self.tokenizer)
+                                                          self.tokenizer,
+                                                          self.global_batch_size * self.n_samples_per_prompt,
+                                                          self.guarantee_order)
             metrics_result = ray.get(self.transfer_dock.get_metrics.remote())
 
             metrics_result = metrics_post_processing(metrics_result)
@@ -179,7 +183,7 @@
 
         logger.info('after grpo training is done')
 
-    def compute_advantage(self, blocking=False):
+    def compute_advantage(self, blocking=False, guarantee_order=False):
         experience_count = get_least_common_multiple(self.micro_batch_size,
                                                      self.n_samples_per_prompt)
 
@@ -191,6 +195,8 @@
             adv_estimator=self.adv_estimator,
             experience_count=experience_count,
             tokenizer=self.tokenizer,
+            global_batch_size=self.global_batch_size * self.n_samples_per_prompt,
+            guarantee_order=guarantee_order
         )
         if blocking:
             ray.get(compute_advantage_ref)
diff --git a/mindspeed_rl/trainer/utils/compute_utils.py b/mindspeed_rl/trainer/utils/compute_utils.py
index ea1033468ad8e11b4702d80909fa2e5d84290bbb..bf51679d12e69d692f8fd02054d3613414d1cf78 100644
--- a/mindspeed_rl/trainer/utils/compute_utils.py
+++ b/mindspeed_rl/trainer/utils/compute_utils.py
@@ -23,7 +23,7 @@ import numpy as np
 
 import mindspeed_rl.utils.torch_functional as F
 from mindspeed_rl.utils.pad_process import truncate_rows
-from mindspeed_rl.utils.utils import generate_mask
+from mindspeed_rl.utils.utils import generate_mask, get_current_dp_range_indexes
 from mindspeed_rl.trainer.utils.transfer_dock import pad_experience
 
 
@@ -125,7 +125,7 @@ def compute_group_norm_advantage_return(token_level_rewards: torch.Tensor, eos_m
 
 
 @ray.remote
-def compute_advantage(td, gamma, lam, adv_estimator, experience_count, tokenizer):
+def compute_advantage(td, gamma, lam, adv_estimator, experience_count, tokenizer, global_batch_size, guarantee_order):
     """
     Compute the advantage function based on different adv_estimator
 
@@ -136,6 +136,8 @@
         adv_estimator: The type of advantage estimator, which can be "gae" or "group_norm"
         experience_count: The number of experiences to retrieve from the experience td
         tokenizer: The pre-trained tokenizer
+        global_batch_size: The global batch size
+        guarantee_order: Whether to preserve sample order when consuming from TransferDock
 
     Returns: None
 
@@ -143,11 +145,13 @@
     experience_consumer_stage = "compute_advantage"
     experience_columns = ["responses", "token_level_rewards", "response_length"]
     pad_token_id = tokenizer.pad if tokenizer.pad is not None else tokenizer.eod
-
+    sorted_indexes = get_current_dp_range_indexes(experience_count=experience_count,
+                                                  assign_batch_size=global_batch_size) if guarantee_order else None
     while not ray.get(td.all_consumed.remote(experience_consumer_stage)):
         batch_data, index = ray.get(
             td.get_experience.remote(
                 experience_consumer_stage, experience_columns, experience_count,  # pad_id=pad_token_id
+                indexes=sorted_indexes.pop(0) if guarantee_order else None
             )
         )
         if batch_data and index:
@@ -193,7 +197,7 @@ def get_last_reward(rm_scores, n_sample_batch: int):
 
 
 def compute_grpo_data_metrics(
-        td, experience_count, tokenizer
+        td, experience_count, tokenizer, global_batch_size, guarantee_order
 ):
     """
     Calculate various metrics for GRPO data
 
@@ -202,6 +206,8 @@
         td: A data queue object
         experience_count: Number of experiences to retrieve
         tokenizer: The pre-trained tokenizer
+        global_batch_size: The global batch size
+        guarantee_order: Whether to preserve sample order when consuming from TransferDock
 
     Returns: Dictionary containing various metric values
 
@@ -217,9 +223,12 @@
         "response_length",
     ]
     pad_token_id = tokenizer.pad if tokenizer.pad is not None else tokenizer.eod
+    sorted_indexes = get_current_dp_range_indexes(experience_count=experience_count,
+                                                  assign_batch_size=global_batch_size) if guarantee_order else None
     while not ray.get(td.all_consumed.remote(experience_consumer_stage)):
         batch, index = ray.get(
-            td.get_experience.remote(experience_consumer_stage, experience_columns, experience_count)
+            td.get_experience.remote(experience_consumer_stage, experience_columns, experience_count,
+                                     indexes=sorted_indexes.pop(0) if guarantee_order else None)
         )
         if batch and index:
             batch = pad_experience(batch, pad_token_id)  # multiple, tp_size
diff --git a/mindspeed_rl/utils/utils.py b/mindspeed_rl/utils/utils.py
index 5293d9e1621d17da418db6f958f6f30dff8b2eb9..8526cc83dd610b326ccf1461e9f286f97d81d99d 100644
--- a/mindspeed_rl/utils/utils.py
+++ b/mindspeed_rl/utils/utils.py
@@ -16,6 +16,11 @@ import torch_npu
 from torch import Tensor
 
 
+def get_current_dp_range_indexes(experience_count, assign_batch_size, current_dp_rank=0):
+    all_indexes = list(range(assign_batch_size * current_dp_rank, assign_batch_size * (current_dp_rank + 1)))
+    return [all_indexes[i:i + experience_count] for i in range(0, len(all_indexes), experience_count)]
+
+
 def synchronize_time():
     """Synchronize training start time across all distributed processes."""
     cur_time = time.time()
@@ -272,6 +277,11 @@ def compute_vllm_throughput(compute_kwargs, metrics_result, gbs, n_samples, time
 def seed_all(seed=1234):
     random.seed(seed)
     os.environ['PYTHONHASHSEED'] = str(seed)
+    os.environ['HCCL_DETERMINISTIC'] = str(True)
+    os.environ['LCCL_DETERMINISTIC'] = str(1)
+    os.environ['CLOSE_MATMUL_K_SHIFT'] = str(1)
+    os.environ['ATB_MATMUL_SHUFFLE_K_ENABLE'] = "0"
+    os.environ['ATB_LLM_LCOC_ENABLE'] = "0"
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.use_deterministic_algorithms(True)
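
The helper added to `mindspeed_rl/utils/utils.py` above slices one data-parallel rank's contiguous share of the global batch into fetch-sized chunks. Below is a minimal standalone sketch with illustrative numbers; `experience_count=4`, `assign_batch_size=16` and `current_dp_rank=1` are assumptions for the example, and in the patch `assign_batch_size` is derived as `global_batch_size // dp_world_size` inside `BaseWorker.get_dp_range_indexes`:

```python
def get_current_dp_range_indexes(experience_count, assign_batch_size, current_dp_rank=0):
    # Contiguous index range owned by this DP rank within the global batch.
    all_indexes = list(range(assign_batch_size * current_dp_rank, assign_batch_size * (current_dp_rank + 1)))
    # Split it into chunks of `experience_count`; a worker pops one chunk per fetch.
    return [all_indexes[i:i + experience_count] for i in range(0, len(all_indexes), experience_count)]


# Example: 16 samples assigned to DP rank 1, fetched 4 at a time, strictly in order.
print(get_current_dp_range_indexes(experience_count=4, assign_batch_size=16, current_dp_rank=1))
# [[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]]
```

Because the chunks are popped front to back, every fetch consumes a deterministic, contiguous slice of that rank's data rather than whatever TransferDock happens to have ready.
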
diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 51d69beb853f21e7eb3dba1f0655a5ebbfcd8fd0..5055f19457e5ecb0fc49a0c06419b5174ab764f2 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -121,7 +121,7 @@ class ActorHybridWorkerBase(BaseWorker):
         experience_columns = ['responses', 'advantages', 'old_log_prob', 'ref_log_prob',
                               'input_ids', 'response_length', 'prompt_length']
-
+
         experience_count = self.megatron_config.global_batch_size // self.parallel_state.get_data_parallel_world_size()
 
         #get lr
@@ -129,14 +129,17 @@
         for param_group in self.optimizer.param_groups:
             learning_rate = param_group['lr']
         ray.get(self.td.update_metrics.remote(key='grpo/lr', value=learning_rate))
-
+        sorted_indexes = self.get_dp_range_indexes(experience_count,
+                                                   use_vllm=False) if self.rl_config.guarantee_order else None
         start_time_defined = False
         count = 0
-        while self.all_consumed(experience_consumer_stage) > 0:
+        while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
             batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
                                                                  experience_columns,
                                                                  experience_count,
                                                                  self.megatron_config.tensor_model_parallel_size,
+                                                                 indexes=sorted_indexes.pop(
+                                                                     0) if self.rl_config.guarantee_order else None,
                                                                  get_n_samples=False)
             if not start_time_defined:
                 start_time = time.time()
@@ -185,14 +188,17 @@
                                                      self.rl_config.n_samples_per_prompt)
         pad_token_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
 
+        sorted_indexes = self.get_dp_range_indexes(experience_count,
+                                                   use_vllm=True) if self.rl_config.guarantee_order else None
         start_time_defined = False
-        while self.all_consumed(experience_consumer_stage, use_vllm=True) > 0:
+        while self.all_consumed(experience_consumer_stage, sorted_indexes, use_vllm=True) > 0:
             batch_data, index = self.dispatch_transfer_dock_data(
                 experience_consumer_stage,
                 experience_columns,
                 experience_count,
                 tp_size=self.megatron_config.tensor_model_parallel_size,
+                indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
                 use_vllm=True
             )
             if not start_time_defined:
@@ -254,13 +260,17 @@ class ActorHybridWorkerBase(BaseWorker):
         experience_columns = ['input_ids', 'responses', 'response_length', 'prompt_length']
         experience_count = get_least_common_multiple(self.megatron_config.micro_batch_size,
                                                      self.rl_config.n_samples_per_prompt)
+        sorted_indexes = self.get_dp_range_indexes(experience_count,
+                                                   use_vllm=False) if self.rl_config.guarantee_order else None
 
         start_time_defined = False
-        while self.all_consumed(experience_consumer_stage) > 0:
+        while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
             batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
                                                                  experience_columns,
                                                                  experience_count,
                                                                  tp_size=self.megatron_config.tensor_model_parallel_size,
+                                                                 indexes=sorted_indexes.pop(
+                                                                     0) if self.rl_config.guarantee_order else None,
                                                                  get_n_samples=False)
             if not start_time_defined:
                 start_time = time.time()
diff --git a/mindspeed_rl/workers/base_worker.py b/mindspeed_rl/workers/base_worker.py
index eaa9bab6f90f269ea84356268a9c3877fc01e0af..3470a1912531ed2247e51077e3f83f4025f8bc62 100644
--- a/mindspeed_rl/workers/base_worker.py
+++ b/mindspeed_rl/workers/base_worker.py
@@ -11,6 +11,7 @@ import torch
 import torch_npu
 import ray
 
+from mindspeed_rl.models.rollout.vllm_adapter.vllm_parallel_state import get_vllm_tp_group_ranks
 from mindspeed_rl.utils.loggers import Loggers
 from mindspeed_rl.utils.tokenizer import BaseTokenizer
 
@@ -28,10 +29,15 @@ from mindspeed_rl.trainer.utils.parallel_state import (
     get_model_parallel_group
 )
 from mindspeed_rl.utils.compute import set_parallel_state, set_vocab_parallel
+from mindspeed_rl.utils.utils import get_current_dp_range_indexes
 from mindspeed_rl.trainer.utils.transfer_dock import pack_experience_columns, unpack_pad_experience
 
 logger = Loggers("base_worker")
 
+_DP_RANGE_DATA_CONSUMED_FLAG = 0
+
+_DP_RANGE_DATA_NOT_CONSUMED_FLAG = 1
+
 
 class BaseRayWorker:
     def __init__(self):
@@ -145,7 +151,11 @@ class BaseWorker(BaseRayWorker, ABC):
         self.td = None
         self.args = None
 
-    def all_consumed(self, experience_consumer_stage, use_vllm=False):
+    def all_consumed(self, experience_consumer_stage, sorted_indexes, use_vllm=False):
+        if self.rl_config.guarantee_order and not sorted_indexes:
+            return _DP_RANGE_DATA_CONSUMED_FLAG
+        elif self.rl_config.guarantee_order:
+            return _DP_RANGE_DATA_NOT_CONSUMED_FLAG
         if use_vllm:
             current_device = next(self.inference_model.model.parameters()).device
         else:
@@ -205,7 +215,7 @@ class BaseWorker(BaseRayWorker, ABC):
 
     def dispatch_transfer_dock_data(self, experience_consumer_stage,
                                     experience_columns, experience_count, tp_size=1,
-                                    use_vllm=False,
+                                    use_vllm=False, indexes=None,
                                     get_n_samples=True):
         pad_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
 
@@ -216,8 +226,8 @@ class BaseWorker(BaseRayWorker, ABC):
         if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
                 get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
             batch_data, index = ray.get(self.td.get_experience.remote(experience_consumer_stage, experience_columns,
-                                                                      experience_count, get_n_samples=get_n_samples))  # CPU data
-
+                                                                      experience_count, indexes=indexes,
+                                                                      get_n_samples=get_n_samples))  # CPU data
             if not index:  # check whether data was fetched; -1 means nothing was fetched
                 index = [-1] * experience_count
@@ -325,3 +335,27 @@ class BaseWorker(BaseRayWorker, ABC):
                            use_vllm) == 0:
             output = {key: value.cpu() if not isinstance(value, List) else value for key, value in output.items()}
             self.td.put_experience.remote(data_dict=output, indexes=index)
+
+    def get_dp_range_indexes(self, experience_count, use_vllm=False):
+        if use_vllm:
+            current_dp_rank, dp_world_size = self.get_vllm_dp_rank()
+        else:
+            current_dp_rank = self.parallel_state.get_data_parallel_rank()
+            dp_world_size = self.parallel_state.get_data_parallel_world_size()
+        assign_batch_size = self.megatron_config.global_batch_size // dp_world_size
+        return get_current_dp_range_indexes(experience_count=experience_count,
+                                            assign_batch_size=assign_batch_size,
+                                            current_dp_rank=current_dp_rank)
+
+    @staticmethod
+    def get_vllm_dp_rank():
+        get_rollout_data_parallel_rank = torch.distributed.get_rank()
+        vllm_dp_groups = get_vllm_tp_group_ranks()
+        if vllm_dp_groups is None:
+            raise ValueError("vllm dp groups is None")
+        for index, dp_group in enumerate(vllm_dp_groups):
+            if get_rollout_data_parallel_rank in dp_group:
+                current_dp_rank = index
+        return current_dp_rank, len(vllm_dp_groups)
+
+
diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py
index 61f9bebe782b15c05387bf3dacff042f46be64ee..7fd06586cc3e6dc3b33e1bd085540e130c7e92e8 100644
--- a/mindspeed_rl/workers/reference_woker.py
+++ b/mindspeed_rl/workers/reference_woker.py
@@ -87,13 +87,17 @@ class ReferenceWorkerBase(BaseWorker):
         experience_columns = ['input_ids', 'responses', 'response_length', 'prompt_length']
         experience_count = get_least_common_multiple(self.megatron_config.micro_batch_size,
                                                      self.rl_config.n_samples_per_prompt)
+        sorted_indexes = self.get_dp_range_indexes(experience_count,
+                                                   use_vllm=False) if self.rl_config.guarantee_order else None
 
         start_time_defined = False
-        while self.all_consumed(experience_consumer_stage) > 0:
+        while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
             batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
                                                                  experience_columns,
                                                                  experience_count,
                                                                  tp_size=self.megatron_config.tensor_model_parallel_size,
+                                                                 indexes=sorted_indexes.pop(
+                                                                     0) if self.rl_config.guarantee_order else None,
                                                                  get_n_samples=False)
 
             if not start_time_defined:
diff --git a/mindspeed_rl/workers/reward_woker.py b/mindspeed_rl/workers/reward_woker.py
index fccdf609484a8049d1f094332ae0d05fe1ad983f..a556904d3a340fd664c92642563ad6770f71fe8c 100644
--- a/mindspeed_rl/workers/reward_woker.py
+++ b/mindspeed_rl/workers/reward_woker.py
@@ -84,13 +84,18 @@
                               *self.megatron_config.dataset_additional_keys]
         experience_count = get_least_common_multiple(self.megatron_config.micro_batch_size,
                                                      self.rl_config.n_samples_per_prompt)
+        sorted_indexes = self.get_dp_range_indexes(experience_count,
+                                                   use_vllm=False) if self.rl_config.guarantee_order else None
 
         start_time_defined = False
-        while not ray.get(self.td.all_consumed.remote(experience_consumer_stage)):
+        while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
             batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
                                                                  experience_columns,
                                                                  experience_count,
-                                                                 tp_size=self.megatron_config.tensor_model_parallel_size)
+                                                                 tp_size=self.megatron_config.tensor_model_parallel_size,
+                                                                 indexes=sorted_indexes.pop(
+                                                                     0) if self.rl_config.guarantee_order else None,
+                                                                 )
             if not start_time_defined:
                 start_time = time.time()
                 start_time_defined = True
diff --git a/mindspeed_rl/workers/rule_reward.py b/mindspeed_rl/workers/rule_reward.py
index 47c32f0d18c3fd4984523aec96e9147a5dc9732a..80778c37ce5dab46761f5c72711bea7f4b79ecc3 100644
--- a/mindspeed_rl/workers/rule_reward.py
+++ b/mindspeed_rl/workers/rule_reward.py
@@ -3,8 +3,8 @@
 import ray
 
 from mindspeed_rl.models.rule_verifier import compute_verifier_score
 from mindspeed_rl.utils.loggers import Loggers
-from mindspeed_rl.utils.utils import get_least_common_multiple
 from mindspeed_rl.trainer.utils.transfer_dock import pad_experience
+from mindspeed_rl.utils.utils import get_least_common_multiple, get_current_dp_range_indexes
 
 logger = Loggers("rule_reward")
 
@@ -26,13 +26,17 @@ class RuleReward(object):
         experience_columns = ['prompts', 'responses', 'response_length', *self.megatron_config.dataset_additional_keys]
         experience_count = get_least_common_multiple(self.megatron_config.micro_batch_size,
                                                      self.rl_config.n_samples_per_prompt)
+        assign_batch_size = self.megatron_config.global_batch_size * self.rl_config.n_samples_per_prompt
+        sorted_indexes = get_current_dp_range_indexes(experience_count=experience_count,
+                                                      assign_batch_size=assign_batch_size) if self.rl_config.guarantee_order else None
+
         pad_token_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
 
         while not ray.get(self.td.all_consumed.remote(experience_consumer_stage)):
             batch_data, index = ray.get(
                 self.td.get_experience.remote(
                     experience_consumer_stage,
                     experience_columns,
-                    experience_count,
+                    experience_count
                 )
             )  # CPU data
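
Taken together, the worker-side changes follow one pattern when `guarantee_order` is enabled: pre-compute the per-DP-rank index chunks, pass one chunk per fetch through `indexes=`, and stop once the chunk list is exhausted instead of polling TransferDock. A condensed sketch of that loop follows; `fetch` and `consume_in_order` are illustrative stand-ins for `dispatch_transfer_dock_data` / `td.get_experience`, not the real API:

```python
from typing import Callable, List

_DP_RANGE_DATA_CONSUMED_FLAG = 0
_DP_RANGE_DATA_NOT_CONSUMED_FLAG = 1


def consume_in_order(fetch: Callable[[List[int]], list], sorted_indexes: List[List[int]]) -> list:
    """Ordered consumption: pop one pre-assigned index chunk per fetch, stop when the chunks run out."""
    results = []
    # Mirrors `while self.all_consumed(stage, sorted_indexes) > 0`: a non-empty chunk list maps to
    # _DP_RANGE_DATA_NOT_CONSUMED_FLAG (1), an empty one to _DP_RANGE_DATA_CONSUMED_FLAG (0).
    while (_DP_RANGE_DATA_NOT_CONSUMED_FLAG if sorted_indexes else _DP_RANGE_DATA_CONSUMED_FLAG) > 0:
        indexes = sorted_indexes.pop(0)      # the exact rows this fetch may take
        results.append(fetch(indexes))       # stands in for dispatch_transfer_dock_data(..., indexes=indexes)
    return results


# Toy usage: DP rank 0 owns rows 0..7 of the global batch and fetches them 4 at a time, in order.
print(consume_in_order(lambda idx: idx, [[0, 1, 2, 3], [4, 5, 6, 7]]))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```
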