From fa7aa8010410a06de8431730cd3fd932ec018a06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=97=AD=E6=97=B6?= Date: Mon, 21 Apr 2025 16:10:36 +0800 Subject: [PATCH 01/10] =?UTF-8?q?=E6=8F=90=E4=BA=A4partial=5Frollout?= =?UTF-8?q?=E4=B8=BB=E6=8E=A7=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/config_cls/generate_config.py | 16 +++ mindspeed_rl/config_cls/validate_config.py | 7 ++ mindspeed_rl/models/actor_rollout_hybrid.py | 4 +- mindspeed_rl/models/rollout/vllm_engine.py | 2 - mindspeed_rl/trainer/base.py | 4 + mindspeed_rl/trainer/grpo_trainer_hybrid.py | 27 ++++- mindspeed_rl/trainer/utils/compute_utils.py | 1 + mindspeed_rl/workers/actor_hybrid_worker.py | 109 +++++++++++++++++++- mindspeed_rl/workers/base_worker.py | 18 ++-- mindspeed_rl/workers/reward_woker.py | 4 +- 10 files changed, 168 insertions(+), 24 deletions(-) diff --git a/mindspeed_rl/config_cls/generate_config.py b/mindspeed_rl/config_cls/generate_config.py index 7ccc61cf..a5e046bd 100644 --- a/mindspeed_rl/config_cls/generate_config.py +++ b/mindspeed_rl/config_cls/generate_config.py @@ -46,6 +46,8 @@ class GenerateConfig(BaseConfig): # 推理时的微批次处理数据 self.micro_batch_size = None + + self.partial_micro_batch_size = None # 推理时的张量并行大小,默认为 8 self.infer_tensor_parallel_size = 8 @@ -76,6 +78,13 @@ class GenerateConfig(BaseConfig): self.enable_prefix_caching = False self.num_scheduler_steps = 1 + # 最大的partial rollout次数 + self.max_partial_times = None + # 最大的partial rollout长度 + self.max_seq_len = None + # 样本分布,k越大代表占据分布越大 + self.k = None + # 采样配置的默认值,用于生成文本时的采样策略设置 self.sampling_config = { "logprobs": 1, # 返回的 top token 的对数概率数量 @@ -90,3 +99,10 @@ class GenerateConfig(BaseConfig): # 如果提供了配置字典,则更新默认值 if config_dict is not None: self.update(config_dict) + + @property + def enable_partial_rollout(self): + enable_partial_rollout = False + if self.max_seq_len and self.max_partial_times: + enable_partial_rollout = True + return enable_partial_rollout diff --git a/mindspeed_rl/config_cls/validate_config.py b/mindspeed_rl/config_cls/validate_config.py index 09c8748e..d1326e09 100644 --- a/mindspeed_rl/config_cls/validate_config.py +++ b/mindspeed_rl/config_cls/validate_config.py @@ -31,6 +31,13 @@ def validate_rl_args( f"Actor.seq_length={actor_config.seq_length} vs " f"GenerateConfig.max_model_len={generate_config.max_model_len}") + # 校验partial_rollout参数合理性 + if generate_config.enable_partial_rollout: + if generate_config.max_seq_len % generate_config.max_partial_times != 0: + raise ValueError( + f"max_seq_len {generate_config.max_seq_len} " + f"must be divisible by max_partial_times {generate_config.max_partial_times}") + # 校验资源分配合理性 def _validate_resource(resource, t_size, p_size, c_size, component): product = t_size * p_size * c_size diff --git a/mindspeed_rl/models/actor_rollout_hybrid.py b/mindspeed_rl/models/actor_rollout_hybrid.py index 7f66cff9..df1bd8fe 100644 --- a/mindspeed_rl/models/actor_rollout_hybrid.py +++ b/mindspeed_rl/models/actor_rollout_hybrid.py @@ -64,8 +64,8 @@ class ActorRolloutHybrid(ABC): self.inference_actor = inference_model self.sharding_manager = sharding_manager - def generate_sequences(self, prompts_list: List[List[int]]) -> Tensor: - responses = self.inference_actor.generate_sequences(prompts_list)[0] + def generate_sequences(self, prompts_list: List[List[int]], **kwargs) -> Tensor: + responses = self.inference_actor.generate_sequences(prompts_list, **kwargs)[0] return responses def compute_log_prob(self, data: Dict) -> 
Tensor: diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 16ce4c31..f03f7ed5 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -266,8 +266,6 @@ class VLLMInferEngine(BaseInferEngine): logprob.append(logprobs_dict[token_id].logprob) logprobs.append(torch.tensor(logprob)) - output_token_ids = pad_sequence(output_token_ids, batch_first=True, - padding_value=self.pad_token_id) if len(logprobs) > 0: logprobs = pad_sequence(logprobs, batch_first=True, padding_value=self.pad_token_id) diff --git a/mindspeed_rl/trainer/base.py b/mindspeed_rl/trainer/base.py index 7b59228f..2ba0d65a 100644 --- a/mindspeed_rl/trainer/base.py +++ b/mindspeed_rl/trainer/base.py @@ -38,6 +38,8 @@ class RayBaseTrainer(object): dataset_additional_keys: List[str] = None, blocking: bool = False, num_cpus_for_local_task: float = 0.1, + max_partial_times: int = None, + max_seq_len: int = None, **kwargs): self.actor_worker = actor_worker @@ -61,6 +63,8 @@ class RayBaseTrainer(object): self.dataset_additional_keys = dataset_additional_keys self.blocking = blocking self.num_cpus_for_local_task = num_cpus_for_local_task + self.max_partial_times = max_partial_times + self.max_seq_len = max_seq_len self.kwargs = kwargs # define KL control diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index d2aef2b4..7315fb5b 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -60,6 +60,8 @@ class RayGRPOTrainer(RayBaseTrainer): dataset_additional_keys: List[str] = None, blocking: bool = False, num_cpus_for_local_task: int = 1, + max_partial_times: int = None, + max_seq_len: int = None, **kwargs ): super().__init__( @@ -80,6 +82,8 @@ class RayGRPOTrainer(RayBaseTrainer): dataset_additional_keys=dataset_additional_keys, blocking=blocking, num_cpus_for_local_task=num_cpus_for_local_task, + max_partial_times=max_partial_times, + max_seq_len=max_seq_len, **kwargs ) @@ -89,7 +93,7 @@ class RayGRPOTrainer(RayBaseTrainer): self.kwargs = kwargs def transfer_dock_init(self): - self.transfer_dock = GRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt, + self.transfer_dock = GRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt, self.max_seq_len, self.max_partial_times, self.metrics, addition_columns=self.dataset_additional_keys) self.actor_worker.sync_init_transfer_dock(self.transfer_dock) self.ref_worker.sync_init_transfer_dock(self.transfer_dock) @@ -114,7 +118,13 @@ class RayGRPOTrainer(RayBaseTrainer): logger.info('sync start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters)) else: logger.info('async start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters)) - + + enable_partial_rollout = self.max_partial_times and self.max_seq_len + if enable_partial_rollout: + global_response_std = [] + global_response_mean = [] + global_num_partial_rollout = [] + global_num_partial_rollout_len = 0 while iteration < self.train_iters: ray.get(self.transfer_dock.clear.remote()) @@ -123,6 +133,11 @@ class RayGRPOTrainer(RayBaseTrainer): with Timer(name='iteration', logger=None) as all_timer: # generate sequences + if enable_partial_rollout: + ray.get(self.transfer_dock.update_metrics.remote("response_length/std", global_response_std, cumulate=True)) + ray.get(self.transfer_dock.update_metrics.remote("response_length/mean", global_response_mean, 
cumulate=True)) + ray.get(self.transfer_dock.update_metrics.remote('total_partial_rollout_indexes', global_num_partial_rollout, cumulate=True)) + global_num_partial_rollout_len = len(global_num_partial_rollout) self.actor_worker.generate_sequences(blocking=self.blocking) # compute rm scores. @@ -155,7 +170,13 @@ class RayGRPOTrainer(RayBaseTrainer): self.global_batch_size * self.n_samples_per_prompt, self.tokenizer) metrics_result = ray.get(self.transfer_dock.get_metrics.remote()) - + if enable_partial_rollout: + global_response_std = [grpo_data_metrics['response_length/std']] + global_response_mean = [grpo_data_metrics[ 'response_length/mean']] + metrics_result.metric["response_length/std"] = global_response_std + metrics_result.metric["response_length/mean"] = global_response_mean + local_num_partial_rollout = metrics_result.metric['total_partial_rollout_indexes'] + global_num_partial_rollout = local_num_partial_rollout[global_num_partial_rollout_len:] metrics_result = metrics_post_processing(metrics_result) metrics_result = metrics_sort(metrics_result, all_timer.last) tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt, all_timer.last) diff --git a/mindspeed_rl/trainer/utils/compute_utils.py b/mindspeed_rl/trainer/utils/compute_utils.py index e096573e..112070d6 100644 --- a/mindspeed_rl/trainer/utils/compute_utils.py +++ b/mindspeed_rl/trainer/utils/compute_utils.py @@ -236,6 +236,7 @@ def compute_grpo_data_metrics( "response_length/mean": torch.mean(response_length, dtype=torch.float32).detach().item(), "response_length/max": torch.max(response_length).detach().item(), "response_length/min": torch.min(response_length).detach().item(), + "response_length/std": torch.std(response_length.to(torch.float32)).detach().item(), # prompt length "prompt_length/mean": torch.mean(prompt_length, dtype=torch.float32).detach().item(), "prompt_length/max": torch.max(prompt_length).detach().item(), diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index 902cb273..5079ccef 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -5,6 +5,7 @@ import dataclasses import copy from typing import Callable +import math import ray from torch import nn import torch @@ -159,6 +160,14 @@ class ActorHybridWorkerBase(BaseWorker): self.num_floating_point_operations_so_far) def generate_sequences(self): + if self.generate_config.enable_partial_rollout: + metrics_result = ray.get(self.td.get_metrics.remote()) + pre_iter_avg_seq_length = metrics_result.metric['response_length/mean'] + pre_iter_std_seq_length = metrics_result.metric['response_length/std'] + max_tokens_first_stage = None + if pre_iter_avg_seq_length and pre_iter_std_seq_length: + k = self.generate_config.k + max_tokens_first_stage = round(pre_iter_avg_seq_length[0] + pre_iter_std_seq_length[0] * k) self.sharding_manager.enter_infer_mode() experience_consumer_stage = 'actor_rollout' @@ -169,13 +178,14 @@ class ActorHybridWorkerBase(BaseWorker): pad_token_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod start_time_defined = False + total_partial_rollout_indexes = [] while self.all_consumed(experience_consumer_stage, use_vllm=True) > 0: batch_data, index = self.dispatch_transfer_dock_data( experience_consumer_stage, experience_columns, experience_count, - tp_size=self.megatron_config.tensor_model_parallel_size, - use_vllm=True + tp_size=self.generate_config.infer_tensor_parallel_size, + 
use_vllm=True, ) if not start_time_defined: start_time = time.time() @@ -188,8 +198,18 @@ class ActorHybridWorkerBase(BaseWorker): prompts = truncate_rows(prompts_data, prompt_length_data) prompts_list = [prompt.numpy().tolist() for prompt in prompts] - responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list)) - responses = remove_padding_and_split_to_list(responses_pad_right, self.tokenizer.eod, pad_token_id) + # inference + if self.generate_config.enable_partial_rollout and pre_iter_avg_seq_length and pre_iter_std_seq_length and max_tokens_first_stage: + responses = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list), max_tokens=max_tokens_first_stage) + else: + responses = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list)) + + partial_rollout_indexes = [] + if self.generate_config.enable_partial_rollout: + for idx, response in enumerate(responses): + if response[-1] != self.tokenizer.eod: + partial_rollout_indexes.append(idx) + total_partial_rollout_indexes.append(index[idx]) responses_length = [torch.tensor([len(response)]) for response in responses] # copy prompts (from 1 to n_samples_per_prompt) @@ -208,7 +228,7 @@ class ActorHybridWorkerBase(BaseWorker): 'input_ids': input_ids_list, 'response_length': responses_length } - self.collect_transfer_dock_data(outputs, index, use_vllm=True) + self.collect_transfer_dock_data(outputs, index, partial_rollout_indexes, use_vllm=True) end_time = time.time() ray.get( self.td.update_metrics.remote( @@ -218,6 +238,77 @@ class ActorHybridWorkerBase(BaseWorker): ) ) + if self.generate_config.enable_partial_rollout: + metrics_result = ray.get(self.td.get_metrics.remote()) + pre_total_partial_rollout_indexes = metrics_result.metric['total_partial_rollout_indexes'] + + if pre_total_partial_rollout_indexes: + dp_num = self.generate_config.data_parallel_size + experience_count = math.ceil(len(pre_total_partial_rollout_indexes)/dp_num) + else: + experience_count = self.generate_config.partial_micro_batch_size + + while self.all_consumed('actor_partial_rollout', use_vllm=True) > 0: + batch_data, index = self.dispatch_transfer_dock_data( + 'actor_partial_rollout', + ['prompts', 'responses', 'prompt_length'], + experience_count, + tp_size=self.generate_config.infer_tensor_parallel_size, + use_vllm=True, + get_n_samples=False, + target_seq_len=pre_iter_avg_seq_length + ) + if batch_data and index: + prompts = batch_data['prompts'] + prompts_list = remove_padding_and_split_to_list(prompts, self.tokenizer.eod, pad_token_id, to_list=True) + first_stage_resps = remove_padding_and_split_to_list(batch_data['responses'], self.tokenizer.eod, pad_token_id) + + prompt_length = batch_data['prompt_length'] + if pre_iter_avg_seq_length and pre_iter_std_seq_length and max_tokens_first_stage: + max_tokens = (self.generate_config.max_seq_len - max_tokens_first_stage) // self.generate_config.max_partial_times + self.print0(f"[tmp log][maxt partial tokens second stage]{max_tokens}") + else: + max_tokens = (self.generate_config.max_seq_len - self.generate_config.sampling_config["max_tokens"]) // self.generate_config.max_partial_times + self.print0(f"[tmp log][maxt partial tokens second stage]{max_tokens}") + + second_satge_resps = self.actor_hybrid.generate_sequences( + copy.deepcopy(prompts_list), + max_tokens=max_tokens, + n=1, + ) + partial_rollout_indexes = [] + for idx, response in enumerate(second_satge_resps): + if response[-1] != self.tokenizer.eod: + partial_rollout_indexes.append(idx) + + 
partial_rollout_merged_responses = [] + for first_resp, second_resp in zip(first_stage_resps, second_satge_resps): + if first_resp[-1] == self.tokenizer.eod: + first_resp = first_resp[:-1] + partial_rollout_merged_responses.append(torch.cat((first_resp.cpu(), second_resp))) + + responses_length = [torch.tensor([len(resp)]) for resp in partial_rollout_merged_responses] + + prompts = [copy.deepcopy(p[:pl.item()]) for p, pl in zip(prompts, prompt_length)] + + input_ids_list = [] + for prompt, response in zip(prompts, partial_rollout_merged_responses): + input_ids_list.append(torch.cat((prompt.cpu(), response), dim=0)) + + outputs = { + 'responses': partial_rollout_merged_responses, + 'input_ids': input_ids_list, + 'response_length': responses_length + } + ray.get( + self.td.update_metrics.remote( + "timing/rollout", + value=[round(time.time(), 4), round(start_time, 4)], + cumulate=True + ) + ) + self.collect_transfer_dock_data(outputs, index, partial_rollout_indexes, use_vllm=True) + from datetime import datetime generate_end_time = time.time() parallel_state = get_parallel_state() use_vllm = True @@ -229,6 +320,14 @@ class ActorHybridWorkerBase(BaseWorker): cumulate=True ) ) + if self.generate_config.enable_partial_rollout: + ray.get( + self.td.update_metrics.remote( + "total_partial_rollout_indexes", + value=total_partial_rollout_indexes, + cumulate=True + ) + ) self.sharding_manager.exit_infer_mode() diff --git a/mindspeed_rl/workers/base_worker.py b/mindspeed_rl/workers/base_worker.py index 72169fb5..11e39a93 100644 --- a/mindspeed_rl/workers/base_worker.py +++ b/mindspeed_rl/workers/base_worker.py @@ -205,7 +205,8 @@ class BaseWorker(BaseRayWorker, ABC): def dispatch_transfer_dock_data(self, experience_consumer_stage, experience_columns, experience_count, tp_size=1, use_vllm=False, - get_n_samples=True): + get_n_samples=True, + target_seq_len=None): pad_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod batch_data = {} @@ -216,11 +217,12 @@ class BaseWorker(BaseRayWorker, ABC): batch_data, index = ray.get(self.td.get_experience.remote(experience_consumer_stage, experience_columns, experience_count, pad_id=pad_id, multiple=tp_size, - get_n_samples=get_n_samples)) # cpu数据 + get_n_samples=get_n_samples, + target_seq_len=target_seq_len)) # cpu数据 if not index: # 判断是否取出数据,未取出数据为-1 index = [-1] * experience_count - index = torch.tensor(index + ([-1] * (experience_count - len(index)))).cuda() + index = torch.tensor(index + ([-1]*(experience_count-len(index)))).cuda() else: index = torch.empty(experience_count, device=torch.cuda.current_device(), dtype=torch.int64) @@ -237,7 +239,6 @@ class BaseWorker(BaseRayWorker, ABC): if index[0].item() == -1: return None, None - index_without_pad = [] for key in experience_columns: if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \ get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0: @@ -292,12 +293,11 @@ class BaseWorker(BaseRayWorker, ABC): batch_data[key].cuda(), get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm), group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm) ) - if not index_without_pad: - index_without_pad = index.cpu().numpy().tolist()[:batch_data_shape[0]] - return batch_data, index_without_pad + index = index.cpu().numpy().tolist()[:batch_data_shape[0]] + return batch_data, index - def collect_transfer_dock_data(self, output, index, use_vllm=False): + def collect_transfer_dock_data(self, output, index, partial_rollout_index=None, use_vllm=False): if 
is_pipeline_last_stage(self.parallel_state, use_vllm) and get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0: output = {key: value.cpu() if not isinstance(value, List) else value for key, value in output.items()} - self.td.put_experience.remote(data_dict=output, indexes=index) + self.td.put_experience.remote(data_dict=output, indexes=index, partial_rollout_indexes=partial_rollout_index) diff --git a/mindspeed_rl/workers/reward_woker.py b/mindspeed_rl/workers/reward_woker.py index 675ea439..12c518ee 100644 --- a/mindspeed_rl/workers/reward_woker.py +++ b/mindspeed_rl/workers/reward_woker.py @@ -90,8 +90,7 @@ class RewardWorkerBase(BaseWorker): batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, experience_columns, experience_count, - tp_size=self.megatron_config.tensor_model_parallel_size, - get_n_samples=False) + tp_size=self.megatron_config.tensor_model_parallel_size) if not start_time_defined: start_time = time.time() start_time_defined = True @@ -132,7 +131,6 @@ class RewardWorkerBase(BaseWorker): value=[round(rwd_end_time, 4)] ) ) - # self.empty_cache() @ray.remote(resources={"NPU": 0.1}) -- Gitee From cd6132137ada80924656ed096011b1df3903a340 Mon Sep 17 00:00:00 2001 From: liyongwen <1310439159@qq.com> Date: Mon, 21 Apr 2025 16:21:19 +0800 Subject: [PATCH 02/10] =?UTF-8?q?=E6=8F=90=E4=BA=A4Partial=20Rollout=20TD?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/utils/transfer_dock.py | 121 ++++++++++++++++---- 1 file changed, 97 insertions(+), 24 deletions(-) diff --git a/mindspeed_rl/trainer/utils/transfer_dock.py b/mindspeed_rl/trainer/utils/transfer_dock.py index e3685664..f15eb800 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock.py +++ b/mindspeed_rl/trainer/utils/transfer_dock.py @@ -71,7 +71,7 @@ class TransferDock(ABC): def _put( self, experience_columns: List[str], - experience: List[List[List[torch.Tensor]]], + experience: List[List[torch.Tensor]], indexes: List[int] = None, ): """Put data into specified columns and rows. @@ -254,6 +254,8 @@ class GRPOTransferDock(TransferDock): self, prompts_num: int, n_samples_per_prompt: int, + max_seq_len: int = None, + max_partial_times: int = None, metrics=None, addition_columns: Union[List[str], None] = None, addition_consumers: Union[List[str], None] = None, @@ -290,6 +292,7 @@ class GRPOTransferDock(TransferDock): self.experience_consumers = [ "trainer", "actor_rollout", + "actor_partial_rollout", "actor_log_prob", "ref_log_prob", "actor_train", @@ -324,6 +327,11 @@ class GRPOTransferDock(TransferDock): for key in self.experience_consumers } self.metrics = metrics + self.max_seq_len = max_seq_len + self.max_partial_times = max_partial_times + self.enable_partial_rollout = max_seq_len and max_partial_times + self.partial_rollout_status = torch.zeros(self.max_len, dtype=torch.int32) + self.completed_rollout = torch.zeros(self.max_len, dtype=torch.int32) def get_metrics(self): return self.metrics @@ -340,6 +348,7 @@ class GRPOTransferDock(TransferDock): pad_id: int = None, multiple: int = 1, get_n_samples: bool = True, + target_seq_len: int = None ): """Get padded experience data from GRPOTransferDock. 
@@ -367,16 +376,26 @@ class GRPOTransferDock(TransferDock): f"get experience ERROR: {experience_column} not in TD experience_column {self.experience_columns}" ) - if indexes is None: - if experience_count > self.max_len: + if consumer == "actor_partial_rollout": + if get_n_samples: raise ValueError( - f"TD max_len: {self.max_len} need >= experience_count: {experience_count}" + "get_n_samples not supported when actor_partial_rollout" + ) + if experience_columns != ["prompts", "responses", "prompt_length"]: + raise ValueError( + "actor_partial_rollout need to get prompts and responses" ) - if self.max_len % experience_count != 0: + if indexes is None: + if experience_count > self.max_len: raise ValueError( - f"TD max_len:{self.max_len} need be divisible by experience_count: {experience_count}" + f"TD max_len: {self.max_len} need >= experience_count: {experience_count}" ) + if consumer != "actor_partial_rollout": + if self.max_len % experience_count != 0: + raise ValueError( + f"TD max_len:{self.max_len} need be divisible by experience_count: {experience_count}" + ) if get_n_samples: if experience_count % self.n_samples_per_prompt != 0: @@ -389,7 +408,7 @@ class GRPOTransferDock(TransferDock): ) else: indexes = self._sample_ready_index( - consumer, experience_count, experience_columns + consumer, experience_count, experience_columns, target_seq_len ) if not indexes: @@ -399,17 +418,34 @@ class GRPOTransferDock(TransferDock): self.experience_consumer_status[consumer][indexes] = 1 experience = self._get(experience_columns, indexes) + if consumer == "actor_partial_rollout": + self.experience_data_status["responses"][indexes] = 0 + self.partial_rollout_status[indexes] = 0 + self.experience_consumer_status[consumer][indexes] = 1 + + experience = self.post_process_partial_rollout(experience) + experience_batch = trans_experience_to_output(experience, experience_columns, pad_id, multiple) return experience_batch, indexes + @staticmethod + def post_process_partial_rollout(experience): + return [ + [torch.cat((prom, resp)) for prom, resp in zip(experience[0], experience[1])], + experience[1], + experience[2], + ] + def put_experience( self, data_dict: Dict[str, Union[Tensor, List[Tensor]]], - indexes: List[int] = None + indexes: List[int] = None, + partial_rollout_indexes=None, ): """Put data into specified columns and rows. Args: + partial_rollout_indexes: Partial rollout responses indexes. data_dict: Data dict to put in GRPOTransferDock. indexes: Rows to put data in. 
@@ -422,6 +458,22 @@ class GRPOTransferDock(TransferDock): "put experience into TD without indexes, indexes must be provided" ) experience_columns, experience = trans_input_to_experience(data_dict) + if "responses" in experience_columns: + if partial_rollout_indexes: + for idx, _ in enumerate(indexes): + if idx not in partial_rollout_indexes: + self.completed_rollout[indexes[idx]] = 1 + else: + if ( + experience[experience_columns.index("responses")][idx].shape[0] < self.max_seq_len + ): + self.partial_rollout_status[indexes[idx]] = 1 + self.experience_consumer_status["actor_partial_rollout"][indexes[idx]] = 0 + else: + self.completed_rollout[indexes[idx]] = 1 + else: + self.completed_rollout[indexes] = 1 + self._put(experience_columns, experience, indexes) def put_prompts_experience( @@ -486,18 +538,28 @@ class GRPOTransferDock(TransferDock): with self.consumer_sampling_lock[consumer]: not_consumed_indexes = self.experience_consumer_status[consumer] == 0 - data_ready_indexes = torch.all( - torch.stack( - [self.experience_data_status[single_column] == 1 for single_column in experience_columns] - ), dim=0, - ) + if consumer == "actor_partial_rollout": + data_ready_indexes = torch.all( + torch.stack( + [self.experience_data_status[single_column] == 1 for single_column in experience_columns] + ), dim=0, + ) & (self.partial_rollout_status == 1) + else: + data_ready_indexes = torch.all( + torch.stack( + [self.experience_data_status[single_column] == 1 for single_column in experience_columns] + ), dim=0, + ) & (self.partial_rollout_status != 1) usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0] if len(usable_indexes) < experience_count: - return None + if consumer == "actor_partial_rollout": + experience_count = len(usable_indexes) + else: + return None if experience_count > 0: - sampled_indexes = self.batch_balencing_sampler( + sampled_indexes = self.batch_balancing_sampler( experience_columns, usable_indexes, experience_count, target_seq_len ) self.experience_consumer_status[consumer][sampled_indexes] = 1 @@ -519,7 +581,7 @@ class GRPOTransferDock(TransferDock): consumer: GRPO task stage to sample in. experience_count: Number for rows to sample. experience_columns: Columns from which to sample. - target_seq_len: Sample according with seq_len and target_seq_len. + target_seq_len: Sample according to seq_len and target_seq_len. Returns: Sampled row numbers. @@ -540,6 +602,8 @@ class GRPOTransferDock(TransferDock): experience_data_status_n_samples = {} for key, value in self.experience_data_status.items(): + if key == "responses" and self.enable_partial_rollout: + value = value & self.completed_rollout experience_data_status_n_samples[key] = torch.all( torch.tensor( torch.reshape(value, (self.prompts_num, self.n_samples_per_prompt)) == 1 @@ -557,7 +621,7 @@ class GRPOTransferDock(TransferDock): if len(usable_indexes) < experience_count_n_samples: return None - sampled_indexes_n_sample = self.batch_balencing_sampler( + sampled_indexes_n_sample = self.batch_balancing_sampler( experience_columns, usable_indexes, experience_count_n_samples, @@ -588,7 +652,10 @@ class GRPOTransferDock(TransferDock): Returns: True or False. """ - return self.experience_consumer_status[consumer].sum() == self.max_len + if consumer == "actor_partial_rollout": + return self.completed_rollout.sum() == self.max_len + else: + return self.experience_consumer_status[consumer].sum() == self.max_len def clear(self): """Reset consumer status.Clear data and data status in GRPOTransferDock. 
@@ -600,6 +667,8 @@ class GRPOTransferDock(TransferDock): key: torch.zeros(self.max_len, dtype=torch.int32) for key in self.experience_consumers } + self.partial_rollout_status = torch.zeros(self.max_len, dtype=torch.int32) + self.completed_rollout = torch.zeros(self.max_len, dtype=torch.int32) self.metrics.reset() self._clear_experience_data_and_status() @@ -611,10 +680,16 @@ class GRPOTransferDock(TransferDock): """ return self.experience_consumer_status - def batch_balencing_sampler( + def get_partial_rollout_status(self): + return self.partial_rollout_status + + def get_completed_rollout_status(self): + return self.completed_rollout + + def batch_balancing_sampler( self, experience_columns, usable_indexes, experience_count, target_seq_len=None ): - if target_seq_len is None: + if not target_seq_len: weights = torch.ones(len(usable_indexes)) else: seq_len = torch.tensor( @@ -623,7 +698,7 @@ class GRPOTransferDock(TransferDock): for idx in usable_indexes ] ) - weights = torch.sigmoid(1 / (torch.abs(seq_len - target_seq_len) + 0.001), dim=0) + weights = torch.sigmoid(1 / (torch.abs(seq_len - target_seq_len[0]) + 0.001)) sampled_indexes_idx = torch.multinomial(weights, experience_count, replacement=False).tolist() sampled_indexes = [int(usable_indexes[i]) for i in sampled_indexes_idx] @@ -709,8 +784,6 @@ def trans_input_to_experience(experience_dict: Dict[str, Union[Tensor, List[Tens tensor([3, 3, 3]), tensor([4, 4, 4, 4])] } - num_responses: The number of data to put in each row. - 2 Returns: Columns and data list. ['prompts', 'attention_mask'] @@ -722,7 +795,7 @@ def trans_input_to_experience(experience_dict: Dict[str, Union[Tensor, List[Tens tensor([4, 4, 4, 4]) ], [ - tensor([1)], + tensor([1]), tensor([2, 2]), tensor([3, 3, 3]), tensor([4, 4, 4, 4]) -- Gitee From 52d1d633e804acb47dc4c87bf00edb3cd3d4ea69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=97=AD=E6=97=B6?= Date: Mon, 21 Apr 2025 17:25:03 +0800 Subject: [PATCH 03/10] =?UTF-8?q?=E6=8F=90=E4=BA=A4partial=20rollout=20cli?= =?UTF-8?q?=E9=80=82=E9=85=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cli/train_grpo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/train_grpo.py b/cli/train_grpo.py index dbae265d..d097da15 100644 --- a/cli/train_grpo.py +++ b/cli/train_grpo.py @@ -149,6 +149,8 @@ def train(config): train_iters=actor_config.train_iters, save_interval=actor_config.save_interval, dataset_additional_keys=actor_config.dataset_additional_keys, + max_seq_len=generate_config.max_seq_len, + max_partial_times=generate_config.max_partial_times, **rl_config.dict() ) @@ -220,8 +222,6 @@ def get_megatron_module(): 'local_ddp': LocalDDP, 'distributed_data_parallel_config': DistributedDataParallelConfig, 'vocab_parallel_cross_entropy': vocab_parallel_cross_entropy, - 'setup_model_and_optimizer': setup_model_and_optimizer, - 'model_type': ModelType, } -- Gitee From 6fa2de66bdba011eab08f9eec4be8e735f96e467 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=97=AD=E6=97=B6?= Date: Mon, 21 Apr 2025 18:12:20 +0800 Subject: [PATCH 04/10] bug fix --- mindspeed_rl/workers/actor_hybrid_worker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index cb171633..c91e8daf 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -267,10 +267,8 @@ class ActorHybridWorkerBase(BaseWorker): prompt_length = 
batch_data['prompt_length'] if pre_iter_avg_seq_length and pre_iter_std_seq_length and max_tokens_first_stage: max_tokens = (self.generate_config.max_seq_len - max_tokens_first_stage) // self.generate_config.max_partial_times - self.print0(f"[tmp log][maxt partial tokens second stage]{max_tokens}") else: max_tokens = (self.generate_config.max_seq_len - self.generate_config.sampling_config["max_tokens"]) // self.generate_config.max_partial_times - self.print0(f"[tmp log][maxt partial tokens second stage]{max_tokens}") second_satge_resps = self.actor_hybrid.generate_sequences( copy.deepcopy(prompts_list), -- Gitee From 4c85c7f0d2fd7a8ed2dd95f1aa82f734ee844b37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=97=AD=E6=97=B6?= Date: Wed, 23 Apr 2025 14:47:13 +0800 Subject: [PATCH 05/10] fix merge bug --- mindspeed_rl/workers/actor_hybrid_worker.py | 3 +-- mindspeed_rl/workers/base_worker.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index 0251469c..da671ad1 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -130,7 +130,7 @@ class ActorHybridWorkerBase(BaseWorker): ray.get(self.td.update_metrics.remote(key='grpo/lr', value=learning_rate)) start_time_defined = False - count = 0 + while self.all_consumed(experience_consumer_stage) > 0: batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, experience_columns, @@ -143,7 +143,6 @@ class ActorHybridWorkerBase(BaseWorker): if batch_data and index: metrics = self.actor_hybrid.update_actor(batch_data, kl_ctrl) - # self.empty_cache() self.args.consumed_train_samples += self.megatron_config.global_batch_size // self.rl_config.n_samples_per_prompt self.num_floating_point_operations_so_far += num_floating_point_operations(self.args, self.megatron_config.global_batch_size) diff --git a/mindspeed_rl/workers/base_worker.py b/mindspeed_rl/workers/base_worker.py index a5eec31d..1a955943 100644 --- a/mindspeed_rl/workers/base_worker.py +++ b/mindspeed_rl/workers/base_worker.py @@ -218,7 +218,6 @@ class BaseWorker(BaseRayWorker, ABC): get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0: batch_data, index = ray.get(self.td.get_experience.remote(experience_consumer_stage, experience_columns, experience_count, - multiple=tp_size, get_n_samples=get_n_samples, target_seq_len=target_seq_len)) # cpu数据 if not index: # 判断是否取出数据,未取出数据为-1 -- Gitee From 7a681d359bcf8b3f003f02b01e6eb81ee34648e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=97=AD=E6=97=B6?= Date: Fri, 25 Apr 2025 11:21:57 +0800 Subject: [PATCH 06/10] fix remove pad bug --- mindspeed_rl/trainer/utils/transfer_dock.py | 22 ++++----------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/mindspeed_rl/trainer/utils/transfer_dock.py b/mindspeed_rl/trainer/utils/transfer_dock.py index eecfbb74..43715546 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock.py +++ b/mindspeed_rl/trainer/utils/transfer_dock.py @@ -833,7 +833,7 @@ def trans_input_to_experience(experience_dict: Dict[str, Union[Tensor, List[Tens return experience_columns, experience_list -def pack_experience_columns(experience_dict, experience_count): +def pack_experience_columns(experience_dict): """ Compress experiences by packing tensors into ONE. 
from experience_dict @@ -858,25 +858,11 @@ def pack_experience_columns(experience_dict, experience_count): 'attention_mask': tensor([1, 2, 3, 4]) } """ - - if not experience_dict: - raise ValueError(f"ERROR: when pack, get an empty experience_dict") - batch_data = {} batch_data_length = {} - - for key, value in experience_dict.items(): - if len(value) != experience_count: - raise ValueError(f"ERROR: when pack, experience '{key}' number does not match experience_count") - packed_experience = [] - data_length = [] - for i in range(experience_count): - packed_experience.extend(value[i].tolist()) - data_length.append(len(value[i])) - - batch_data[key] = torch.tensor(packed_experience, dtype=value[0].dtype) - batch_data_length[key] = torch.tensor(data_length, dtype=torch.int32) - + for column, experiences in experience_dict.items(): + batch_data[column] = torch.cat(experiences) + batch_data_length[column] = torch.tensor([len(e) for e in experiences], dtype=torch.int32) return batch_data, batch_data_length -- Gitee From b76149246054c8d3d714efc392c9ce7546f6741c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=97=AD=E6=97=B6?= Date: Tue, 29 Apr 2025 16:13:25 +0800 Subject: [PATCH 07/10] bug fix --- mindspeed_rl/workers/actor_hybrid_worker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index da671ad1..59356432 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -171,7 +171,8 @@ class ActorHybridWorkerBase(BaseWorker): max_tokens_first_stage = None if pre_iter_avg_seq_length and pre_iter_std_seq_length: k = self.generate_config.k - max_tokens_first_stage = round(pre_iter_avg_seq_length[0] + pre_iter_std_seq_length[0] * k) + max_tokens_first_stage = min(round(pre_iter_avg_seq_length[0] + pre_iter_std_seq_length[0] * k), + self.generate_config.max_seq_len) self.sharding_manager.enter_infer_mode() experience_consumer_stage = 'actor_rollout' -- Gitee From 5dbc616a2f53298c1ce79485f1d3f051a9b85769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=97=AD=E6=97=B6?= Date: Sat, 10 May 2025 19:24:25 +0800 Subject: [PATCH 08/10] =?UTF-8?q?=E7=A7=BB=E9=99=A4total=5Fpartial=5Frollo?= =?UTF-8?q?ut=5Findexes=EF=BC=8C=E6=9B=B4=E6=8D=A2DP=E5=9D=87=E8=A1=A1?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/grpo_trainer_hybrid.py | 6 ++--- mindspeed_rl/workers/actor_hybrid_worker.py | 25 +++++---------------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index ff612b34..8ff87c5c 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -137,7 +137,6 @@ class RayGRPOTrainer(RayBaseTrainer): if enable_partial_rollout: ray.get(self.transfer_dock.update_metrics.remote("response_length/std", global_response_std, cumulate=True)) ray.get(self.transfer_dock.update_metrics.remote("response_length/mean", global_response_mean, cumulate=True)) - ray.get(self.transfer_dock.update_metrics.remote('total_partial_rollout_indexes', global_num_partial_rollout, cumulate=True)) global_num_partial_rollout_len = len(global_num_partial_rollout) self.actor_worker.generate_sequences(blocking=self.blocking) @@ -172,13 +171,12 @@ class RayGRPOTrainer(RayBaseTrainer): self.global_batch_size * 
self.n_samples_per_prompt, self.tokenizer) metrics_result = ray.get(self.transfer_dock.get_metrics.remote()) - if enable_partial_rollout: + if enable_partial_rollout: global_response_std = [grpo_data_metrics['response_length/std']] global_response_mean = [grpo_data_metrics[ 'response_length/mean']] metrics_result.metric["response_length/std"] = global_response_std metrics_result.metric["response_length/mean"] = global_response_mean - local_num_partial_rollout = metrics_result.metric['total_partial_rollout_indexes'] - global_num_partial_rollout = local_num_partial_rollout[global_num_partial_rollout_len:] + metrics_result = metrics_post_processing(metrics_result) metrics_result = metrics_sort(metrics_result, all_timer.last) tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt, all_timer.last) diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index 59356432..9b16a8fe 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -183,7 +183,6 @@ class ActorHybridWorkerBase(BaseWorker): pad_token_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod start_time_defined = False - total_partial_rollout_indexes = [] while self.all_consumed(experience_consumer_stage, use_vllm=True) > 0: batch_data, index = self.dispatch_transfer_dock_data( experience_consumer_stage, @@ -214,7 +213,6 @@ class ActorHybridWorkerBase(BaseWorker): for idx, response in enumerate(responses): if response[-1] != self.tokenizer.eod: partial_rollout_indexes.append(idx) - total_partial_rollout_indexes.append(index[idx]) responses_length = [torch.tensor([len(response)]) for response in responses] @@ -244,14 +242,11 @@ class ActorHybridWorkerBase(BaseWorker): ) if self.generate_config.enable_partial_rollout: - metrics_result = ray.get(self.td.get_metrics.remote()) - pre_total_partial_rollout_indexes = metrics_result.metric['total_partial_rollout_indexes'] - - if pre_total_partial_rollout_indexes: - dp_num = self.generate_config.data_parallel_size - experience_count = math.ceil(len(pre_total_partial_rollout_indexes)/dp_num) - else: - experience_count = self.generate_config.partial_micro_batch_size + torch.distributed.barrier() + completed_rollout = ray.get(self.td.get_completed_rollout_status.remote()) + partial_resp_num = torch.sum(completed_rollout == 0).item() + dp_num = self.generate_config.data_parallel_size + experience_count = max(math.ceil(partial_resp_num / dp_num), 1) while self.all_consumed('actor_partial_rollout', use_vllm=True) > 0: batch_data, index = self.dispatch_transfer_dock_data( @@ -311,7 +306,7 @@ class ActorHybridWorkerBase(BaseWorker): ) ) self.collect_transfer_dock_data(outputs, index, partial_rollout_indexes, use_vllm=True) - from datetime import datetime + generate_end_time = time.time() parallel_state = get_parallel_state() use_vllm = True @@ -323,14 +318,6 @@ class ActorHybridWorkerBase(BaseWorker): cumulate=True ) ) - if self.generate_config.enable_partial_rollout: - ray.get( - self.td.update_metrics.remote( - "total_partial_rollout_indexes", - value=total_partial_rollout_indexes, - cumulate=True - ) - ) self.sharding_manager.exit_infer_mode() -- Gitee From 1c89fd8957d5c0e82630a62f18015b4f23fac457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=97=AD=E6=97=B6?= Date: Sat, 10 May 2025 19:27:14 +0800 Subject: [PATCH 09/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/grpo_trainer_hybrid.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index 8ff87c5c..99297cf3 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -124,8 +124,6 @@ class RayGRPOTrainer(RayBaseTrainer): if enable_partial_rollout: global_response_std = [] global_response_mean = [] - global_num_partial_rollout = [] - global_num_partial_rollout_len = 0 while iteration < self.train_iters: ray.get(self.transfer_dock.clear.remote()) @@ -137,7 +135,6 @@ class RayGRPOTrainer(RayBaseTrainer): if enable_partial_rollout: ray.get(self.transfer_dock.update_metrics.remote("response_length/std", global_response_std, cumulate=True)) ray.get(self.transfer_dock.update_metrics.remote("response_length/mean", global_response_mean, cumulate=True)) - global_num_partial_rollout_len = len(global_num_partial_rollout) self.actor_worker.generate_sequences(blocking=self.blocking) # compute rm scores. -- Gitee From a7d937d3c816ce38be2783fbe1193aa5f5fb5cb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=97=AD=E6=97=B6?= Date: Sat, 10 May 2025 19:56:29 +0800 Subject: [PATCH 10/10] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E7=99=BE=E5=88=86?= =?UTF-8?q?=E4=BD=8D=E4=BB=A3=E7=A0=81=E4=BB=A5=E5=8F=8A=E5=88=A0=E9=99=A4?= =?UTF-8?q?mean/std=E8=AE=A1=E7=AE=97=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/grpo_trainer_hybrid.py | 12 ++++-------- mindspeed_rl/trainer/utils/compute_utils.py | 3 ++- mindspeed_rl/workers/actor_hybrid_worker.py | 16 ++++++---------- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index 03bff9de..916d660e 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -132,8 +132,7 @@ class RayGRPOTrainer(RayBaseTrainer): enable_partial_rollout = self.max_partial_times and self.max_seq_len if enable_partial_rollout: - global_response_std = [] - global_response_mean = [] + global_response_quantile = [] while iteration < self.train_iters: ray.get(self.transfer_dock.clear.remote()) @@ -143,8 +142,7 @@ class RayGRPOTrainer(RayBaseTrainer): with Timer(name='iteration', logger=None) as all_timer: # generate sequences if enable_partial_rollout: - ray.get(self.transfer_dock.update_metrics.remote("response_length/std", global_response_std, cumulate=True)) - ray.get(self.transfer_dock.update_metrics.remote("response_length/mean", global_response_mean, cumulate=True)) + ray.get(self.transfer_dock.update_metrics.remote("response_length/0.9quantile", global_response_quantile, cumulate=True)) self.actor_worker.generate_sequences(blocking=self.blocking) # compute rm scores. 
@@ -183,10 +181,8 @@ class RayGRPOTrainer(RayBaseTrainer): self.guarantee_order) metrics_result = ray.get(self.transfer_dock.get_metrics.remote()) if enable_partial_rollout: - global_response_std = [grpo_data_metrics['response_length/std']] - global_response_mean = [grpo_data_metrics[ 'response_length/mean']] - metrics_result.metric["response_length/std"] = global_response_std - metrics_result.metric["response_length/mean"] = global_response_mean + global_response_quantile = [grpo_data_metrics['response_length/0.9quantile']] + metrics_result.metric["response_length/0.9quantile"] = global_response_quantile metrics_result = metrics_post_processing(metrics_result) metrics_result = metrics_sort(metrics_result, all_timer.last) diff --git a/mindspeed_rl/trainer/utils/compute_utils.py b/mindspeed_rl/trainer/utils/compute_utils.py index fbf485f3..20b9331f 100644 --- a/mindspeed_rl/trainer/utils/compute_utils.py +++ b/mindspeed_rl/trainer/utils/compute_utils.py @@ -249,7 +249,8 @@ def compute_grpo_data_metrics( "response_length/mean": torch.mean(response_length, dtype=torch.float32).detach().item(), "response_length/max": torch.max(response_length).detach().item(), "response_length/min": torch.min(response_length).detach().item(), - "response_length/std": torch.std(response_length.to(torch.float32)).detach().item(), + "response_length/0.8quantile": torch.quantile(response_length.to(torch.float32), 0.8).detach().item(), + "response_length/0.9quantile": torch.quantile(response_length.to(torch.float32), 0.9).detach().item(), # prompt length "prompt_length/mean": torch.mean(prompt_length, dtype=torch.float32).detach().item(), "prompt_length/max": torch.max(prompt_length).detach().item(), diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index f7fc9090..6990d5d7 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -184,13 +184,10 @@ class ActorHybridWorkerBase(BaseWorker): def generate_sequences(self): if self.generate_config.enable_partial_rollout: metrics_result = ray.get(self.td.get_metrics.remote()) - pre_iter_avg_seq_length = metrics_result.metric['response_length/mean'] - pre_iter_std_seq_length = metrics_result.metric['response_length/std'] + pre_iter_qt_seq_length = metrics_result.metric['response_length/0.9quantile'] max_tokens_first_stage = None - if pre_iter_avg_seq_length and pre_iter_std_seq_length: - k = self.generate_config.k - max_tokens_first_stage = min(round(pre_iter_avg_seq_length[0] + pre_iter_std_seq_length[0] * k), - self.generate_config.max_seq_len) + if pre_iter_qt_seq_length: + max_tokens_first_stage = round(pre_iter_qt_seq_length[0]) start_sharding_enter_infer = time.time() self.sharding_manager.enter_infer_mode() @@ -227,7 +224,7 @@ class ActorHybridWorkerBase(BaseWorker): prompts_list = [prompt.numpy().tolist() for prompt in prompts] # inference - if self.generate_config.enable_partial_rollout and pre_iter_avg_seq_length and pre_iter_std_seq_length and max_tokens_first_stage: + if self.generate_config.enable_partial_rollout and pre_iter_qt_seq_length and max_tokens_first_stage: responses = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list), max_tokens=max_tokens_first_stage) else: responses = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list)) @@ -279,8 +276,7 @@ class ActorHybridWorkerBase(BaseWorker): experience_count, tp_size=self.generate_config.infer_tensor_parallel_size, use_vllm=True, - get_n_samples=False, - 
target_seq_len=pre_iter_avg_seq_length + get_n_samples=False ) if batch_data and index: prompts = batch_data['prompts'] @@ -288,7 +284,7 @@ class ActorHybridWorkerBase(BaseWorker): first_stage_resps = remove_padding_and_split_to_list(batch_data['responses'], self.tokenizer.eod, pad_token_id) prompt_length = batch_data['prompt_length'] - if pre_iter_avg_seq_length and pre_iter_std_seq_length and max_tokens_first_stage: + if pre_iter_qt_seq_length and max_tokens_first_stage: max_tokens = (self.generate_config.max_seq_len - max_tokens_first_stage) // self.generate_config.max_partial_times else: max_tokens = (self.generate_config.max_seq_len - self.generate_config.sampling_config["max_tokens"]) // self.generate_config.max_partial_times -- Gitee
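
Taken together, the series implements a two-stage "partial rollout": the first generation pass is capped at a token budget derived from the previous iteration's response-length statistics (originally mean + k*std, switched to the 0.9 quantile in patch 10), responses that did not emit EOS are flagged in the TransferDock (partial_rollout_status / completed_rollout) and re-dispatched through the 'actor_partial_rollout' consumer stage, and each continuation pass gets a budget of (max_seq_len - first_stage_budget) // max_partial_times before the pieces are concatenated back together. The sketch below is a minimal, standalone illustration of that budget logic only; generate_fn, partial_rollout, prev_quantile_len and the other names are illustrative stand-ins, not MindSpeed-RL APIs, and the sketch collapses the TransferDock bookkeeping into a local list and falls back to max_seq_len (rather than sampling_config["max_tokens"]) when no prior statistics exist.

    # Minimal sketch of the partial-rollout budget/continuation loop (assumed names).
    from typing import Callable, List, Optional

    def partial_rollout(prompts: List[List[int]],
                        generate_fn: Callable[[List[List[int]], int], List[List[int]]],
                        eos_id: int,
                        max_seq_len: int,
                        max_partial_times: int,
                        prev_quantile_len: Optional[float] = None) -> List[List[int]]:
        # Stage 1: budget is the previous iteration's 0.9 length quantile if known,
        # otherwise (in this sketch) the full max_seq_len.
        first_budget = round(prev_quantile_len) if prev_quantile_len else max_seq_len
        first_budget = min(first_budget, max_seq_len)
        responses = generate_fn(prompts, first_budget)

        # Stage 2..N: responses without EOS are continued in smaller fixed-size chunks.
        step_budget = max((max_seq_len - first_budget) // max_partial_times, 1)
        for _ in range(max_partial_times):
            pending = [i for i, r in enumerate(responses)
                       if r and r[-1] != eos_id and len(r) < max_seq_len]
            if not pending:
                break
            continuations = generate_fn([prompts[i] + responses[i] for i in pending],
                                        step_budget)
            for i, extra in zip(pending, continuations):
                responses[i] = responses[i] + extra
        return responses

Capping the first pass near the previous iteration's 0.9 quantile lets most sequences finish in one stage while the long tail is deferred to bounded continuation passes, so a few very long generations no longer stall the whole rollout batch.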