diff --git a/configs/model/templates.json b/configs/model/templates.json
index 8b3142679cbc93f08186c9c1da9218fc83d5079d..f32e4967836ca40ffcb62f6075565bdb2d34dc5a 100644
--- a/configs/model/templates.json
+++ b/configs/model/templates.json
@@ -132,6 +132,33 @@
     ],
     "replace_eos": true
   },
+  {
+    "name": "qwq",
+    "format_user": {
+      "slots": [
+        "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n\n"
+      ]
+    },
+    "format_system": {
+      "slots": [
+        "<|im_start|>system\n{{content}}<|im_end|>\n"
+      ]
+    },
+    "format_observation": {
+      "slots": [
+        "<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n\n"
+      ]
+    },
+    "format_separator": {
+      "slots": [
+        "\n"
+      ]
+    },
+    "stop_words": [
+      "<|im_end|>"
+    ],
+    "replace_eos": true
+  },
   {
     "name": "llama3",
     "format_user": {
diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py
index b743b40fcbbfb34a854c895db15b04dabf35fea7..56fa037c1a093cee866c2a8dfddcdc7d7db4703e 100644
--- a/mindspeed_rl/config_cls/rl_config.py
+++ b/mindspeed_rl/config_cls/rl_config.py
@@ -45,6 +45,7 @@ class RLConfig(BaseConfig):
         wandb_exp_name: The wandb experiment name. If use_wandb is True, you need to set the wandb experiment name (default: "")
         wandb_save_dir: Path to save the wandb results locally. (default: "")
         blocking: Whether to enable blocking mode (default: False)
+        async_engine: Whether to enable asynchronous generation, in which samples that finish first are returned first (default: False)
         num_cpus_for_local_task: Number of CPUs for local ray task (default: 1)
         num_cpus_for_placement_group: Number of CPUs for ray worker placement group
     # Default values can still be defined if no config is provided
@@ -76,7 +77,8 @@
         self.verifier_weight = [1.0, ]
         self.verifier_parallel = 1
         self.verifier_timeout = 30
-
+
+        self.async_engine = False
         self.experience_count = 1
         self.experience_count_actor = None
         self.experience_count_reward = None
diff --git a/mindspeed_rl/models/actor_rollout_hybrid.py b/mindspeed_rl/models/actor_rollout_hybrid.py
index 7f66cff98b6c3eca465c43081dd7d481a6865eab..8288bfc4d0ccffc5655e5f9f28cea652870b4f84 100644
--- a/mindspeed_rl/models/actor_rollout_hybrid.py
+++ b/mindspeed_rl/models/actor_rollout_hybrid.py
@@ -64,9 +64,17 @@ class ActorRolloutHybrid(ABC):
         self.inference_actor = inference_model
         self.sharding_manager = sharding_manager
 
-    def generate_sequences(self, prompts_list: List[List[int]]) -> Tensor:
-        responses = self.inference_actor.generate_sequences(prompts_list)[0]
-        return responses
+    def generate_sequences(self, prompts_list: List[List[int]], indexs=None, n_samples_per_prompt=None, async_engine=False) -> Tensor:
+        if async_engine:
+            res = self.inference_actor.async_generate_sequences(
+                prompts_list,
+                indexs,
+                n_samples_per_prompt=n_samples_per_prompt
+            )
+        else:
+            res = self.inference_actor.generate_sequences(prompts_list)[0]
+
+        return res
 
     def compute_log_prob(self, data: Dict) -> Tensor:
         return self.train_actor.compute_log_prob(data)
diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py
index 16ce4c31b84842730bdad8e945f82d8592705417..fc834a99c8419b4deb2bc1a1c5d3ba25daff7057 100644
--- a/mindspeed_rl/models/rollout/vllm_engine.py
+++ b/mindspeed_rl/models/rollout/vllm_engine.py
@@ -3,6 +3,7 @@
 # Copyright 2023 The vLLM team.
 
 import os
+import uuid
 from contextlib import contextmanager
 import gc
 
@@ -170,6 +171,7 @@ class VLLMInferEngine(BaseInferEngine):
             max_model_len=max_model_len
         )
 
+        self.engine = self.llm.llm_engine
         self.model = self.llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
 
         self.cpu_model = {}
@@ -250,6 +252,28 @@ class VLLMInferEngine(BaseInferEngine):
         self.free_cache_engine()
         return outs
 
+    @torch.no_grad()
+    def async_generate_sequences(self, idx_list, indexs, n_samples_per_prompt=None, **kwargs):
+        with self.update_sampling_params(**kwargs):
+            for i, prompt_token_ids in enumerate(idx_list):
+                request_id = f"req_{indexs[i]}_{uuid.uuid4().hex[:6]}"
+
+                self.engine.add_request(
+                    request_id=request_id,
+                    prompt={"prompt_token_ids": prompt_token_ids},
+                    params=self.sampling_params
+                )
+
+            while self.engine.has_unfinished_requests():
+                step_outputs = self.engine.step()
+                for output in step_outputs:
+                    if output.finished:
+                        request_id = output.request_id
+                        index = int(request_id.split("_")[1])
+                        prompt_ids = [torch.tensor(prompt_token_ids) for _ in range(n_samples_per_prompt)]
+                        response_ids = self._post_process_outputs([output])
+                        yield (prompt_ids, *response_ids), index
+
     def _post_process_outputs(self, request_outputs):
         output_token_ids = []
         logprobs = []
diff --git a/mindspeed_rl/trainer/base.py b/mindspeed_rl/trainer/base.py
index fcf5611322cab231366e534667285b857e5b8cf4..f4ab2114b2340aba1109a125acdbb4558cc27e41 100644
--- a/mindspeed_rl/trainer/base.py
+++ b/mindspeed_rl/trainer/base.py
@@ -37,6 +37,7 @@ class RayBaseTrainer(object):
                  tokenizer: BaseTokenizer = None,
                  dataset_additional_keys: List[str] = None,
                  blocking: bool = False,
+                 async_engine: bool = False,
                  num_cpus_for_local_task: float = 0.1,
                  **kwargs):
 
@@ -60,6 +61,7 @@
         self.tokenizer = tokenizer
         self.dataset_additional_keys = dataset_additional_keys
         self.blocking = blocking
+        self.async_engine = async_engine
         self.num_cpus_for_local_task = num_cpus_for_local_task
         self.kwargs = kwargs
 
diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py
index 079a3f30570d693d626d9cb20a9a6273081bebba..a90c779ec4c908add6cccb481b79ff103ee57136 100644
--- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py
+++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py
@@ -146,7 +146,10 @@ class RayGRPOTrainer(RayBaseTrainer):
 
             with Timer(name='iteration', logger=None) as all_timer:
                 # generate sequences
-                self.actor_worker.generate_sequences(blocking=self.blocking)
+                if self.async_engine:
+                    self.actor_worker.async_generate_sequences(blocking=self.blocking)
+                else:
+                    self.actor_worker.generate_sequences(blocking=self.blocking)
 
                 # compute rm scores.
                 for reward_worker in self.reward_list:
diff --git a/mindspeed_rl/trainer/utils/parallel_state.py b/mindspeed_rl/trainer/utils/parallel_state.py
index da714f044699175ad98391adc0299e0fb40b7ca9..1c3043435d545d0585834fa5b66df8e94308097c 100644
--- a/mindspeed_rl/trainer/utils/parallel_state.py
+++ b/mindspeed_rl/trainer/utils/parallel_state.py
@@ -59,9 +59,15 @@ def get_tensor_model_parallel_group(mpu, use_vllm=False):
 
 def get_model_parallel_group(mpu, use_vllm=False):
     if use_vllm:
+        import vllm
         from vllm.distributed import parallel_state as vpu
-        if not hasattr(vpu, "get_tensor_model_parallel_group"):
-            vpu = mpu
-        return vpu.get_model_parallel_group()
+
+        if vllm.__version__ == "0.7.3":
+            # In 0.7.3, vLLM's offline inference API only supports tensor parallelism
+            return vpu.get_tensor_model_parallel_group().device_group
+        else:
+            if not hasattr(vpu, "get_tensor_model_parallel_group"):
+                vpu = mpu
+            return vpu.get_model_parallel_group()
     else:
         return mpu.get_model_parallel_group()
\ No newline at end of file
diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 96987d07fb3cfa44f757813843fb0b945aa50553..6ea2ab5822aa4062a7fdba7dbc0dcae1580654f0 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -163,7 +163,7 @@ class ActorHybridWorkerBase(BaseWorker):
         pad_token_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
 
         start_time_defined = False
-        while self.all_consumed(experience_consumer_stage) > 0:
+        while self.all_consumed(experience_consumer_stage, use_vllm=True) > 0:
             batch_data, index = self.dispatch_transfer_dock_data(
                 experience_consumer_stage,
                 experience_colums,
@@ -228,6 +228,90 @@
 
         self.sharding_manager.exit_infer_mode()
 
+    def async_generate_sequences(self):
+        self.sharding_manager.enter_infer_mode()
+
+        experience_consumer_stage = 'actor_rollout'
+        experience_colums = ['prompts', 'prompt_length']
+        experience_count = self.rl_config.experience_count_actor // self.generate_config.data_parallel_size
+
+        pad_token_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
+
+        start_time_defined = False
+        while self.all_consumed(experience_consumer_stage, use_vllm=True) > 0:
+            batch_data, index = self.dispatch_transfer_dock_data(
+                experience_consumer_stage,
+                experience_colums,
+                experience_count,
+                n_samples_per_prompt=self.rl_config.n_samples_per_prompt,
+                tp_size=self.megatron_config.tensor_model_parallel_size,
+                use_vllm=True
+            )
+
+            if not start_time_defined:
+                start_time = time.time()
+                start_time_defined = True
+
+            if batch_data and index:
+                indexes = list(range(0, experience_count * self.rl_config.n_samples_per_prompt,
+                                     self.rl_config.n_samples_per_prompt))
+                prompts_data = batch_data['prompts'][indexes]
+                prompt_length_data = batch_data['prompt_length'][indexes]
+
+                # preprocess, remove padding
+                prompts = truncate_rows(prompts_data, prompt_length_data)
+                prompts_list = [prompt.numpy().tolist() for prompt in prompts]
+
+                # inference
+                self.actor_hybrid.inference_actor.init_cache_engine()
+                response_generator = self.actor_hybrid.generate_sequences(
+                    copy.deepcopy(prompts_list),
+                    indexs=copy.deepcopy(index),
+                    n_samples_per_prompt=self.rl_config.n_samples_per_prompt,
+                    async_engine=True,
+                )
+
+                for samples, idx in response_generator:
+                    prompts, responses, log_probs = samples
+                    responses = remove_padding_and_split_to_list(responses, self.tokenizer.eod, pad_token_id)
+                    responses_length = [torch.tensor([len(response)]) for response in responses]
+
+                    input_ids_list = []
+                    for prompt, response in zip(prompts, responses):
+                        input_ids_list.append(torch.cat((prompt, response), dim=0))
+
+                    outputs = {
+                        'responses': responses,
+                        'input_ids': input_ids_list,
+                        'response_length': responses_length
+                    }
+
+                    self.collect_transfer_dock_data(outputs, [idx], self.rl_config.n_samples_per_prompt, use_vllm=True)
+
+                end_time = time.time()
+                ray.get(
+                    self.td.update_metrics.remote(
+                        "timing/rollout",
+                        value=[round(end_time, 4), round(start_time, 4)],
+                        cumulate=True
+                    )
+                )
+                self.actor_hybrid.inference_actor.free_cache_engine()
+
+        generate_end_time = time.time()
+        parallel_state = get_parallel_state()
+        use_vllm = True
+        if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0:
+            ray.get(
+                self.td.update_metrics.remote(
+                    "end_time/generate",
+                    value=[round(generate_end_time, 4)],
+                    cumulate=True
+                )
+            )
+
+        self.sharding_manager.exit_infer_mode()
+
     def compute_log_prob(self):
         self.sharding_manager.enter_forward_mode()
 
diff --git a/mindspeed_rl/workers/base_worker.py b/mindspeed_rl/workers/base_worker.py
index 4b0a7c2bf8c2ba4921d684df31924b0195e32424..6fed82cb092a14c13be0a63f0dea23bf5f6a6c5e 100644
--- a/mindspeed_rl/workers/base_worker.py
+++ b/mindspeed_rl/workers/base_worker.py
@@ -145,11 +145,15 @@ class BaseWorker(BaseRayWorker, ABC):
         self.args = None
 
     def all_consumed(self, experience_consumer_stage, use_vllm=False):
-        status = torch.tensor(0, device=next(self.model[0].parameters()).device)
+        if use_vllm:
+            current_device = next(self.inference_model.model.parameters()).device
+        else:
+            current_device = next(self.model[0].parameters()).device
+        status = torch.tensor(0, device=current_device)
         if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
                 get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
             status = torch.tensor(int(not ray.get(self.td.all_consumed.remote(experience_consumer_stage))),
-                                  device=next(self.model[0].parameters()).device)
+                                  device=current_device)
         torch.distributed.all_reduce(status, group=get_model_parallel_group(self.parallel_state, use_vllm),
                                      op=torch.distributed.ReduceOp.MAX)
         return status
diff --git a/mindspeed_rl/workers/scheduler/launcher.py b/mindspeed_rl/workers/scheduler/launcher.py
index ceaed6fb44985a3976df4f5cb666d6da7c369673..e432909c353ddfb2c452ea8d663beb62b7f38ce6 100644
--- a/mindspeed_rl/workers/scheduler/launcher.py
+++ b/mindspeed_rl/workers/scheduler/launcher.py
@@ -260,6 +260,12 @@ class RayActorGroup:
         if blocking:
             ray.get(self.temp_actor_ref_objs)
 
+    def async_generate_sequences(self, blocking=False):
+        for actor in self.actor_handlers:
+            self.temp_actor_ref_objs.append(actor.async_generate_sequences.remote())
+        if blocking:
+            ray.get(self.temp_actor_ref_objs)
+
     def compute_log_prob(self, blocking=False):
         for actor in self.actor_handlers:
             self.temp_actor_ref_objs.append(actor.compute_log_prob.remote())