From 0a14c510e5efbb351152d10291a49742ad2ebc87 Mon Sep 17 00:00:00 2001
From: zhuyilin
Date: Mon, 21 Apr 2025 15:31:07 +0800
Subject: [PATCH] reward: parallelize rule-based reward scoring

---
 cli/train_grpo.py                           | 13 ++-
 mindspeed_rl/config_cls/rl_config.py        |  1 +
 mindspeed_rl/trainer/grpo_trainer_hybrid.py |  3 +-
 mindspeed_rl/utils/kwai.py                  | 98 ++++++++++++-----
 4 files changed, 82 insertions(+), 33 deletions(-)

diff --git a/cli/train_grpo.py b/cli/train_grpo.py
index dbae265d..773c5643 100644
--- a/cli/train_grpo.py
+++ b/cli/train_grpo.py
@@ -32,6 +32,7 @@ from mindspeed_rl.workers.actor_hybrid_worker import ActorHybridWorker
 from mindspeed_rl.workers.reference_woker import ReferenceWorker
 from mindspeed_rl.workers.reward_woker import RewardWorker
 from mindspeed_rl.workers.integrated_worker import IntegratedWorker
+from ray.util.placement_group import placement_group
 
 cur_file_dir = Path(__file__).absolute().parent.parent
 logger = Loggers("grpo_train")
@@ -107,9 +108,15 @@ def train(config):
             reward_list.append(reward_worker)
 
     if rl_config.rule_reward:
-        rule_reward = RuleReward.options(num_cpus=rl_config.num_cpus_for_local_task).remote()
-        rule_reward.initialize.remote(reward_config, rl_config, tokenizer)
-        reward_list.append(rule_reward)
+        pg = placement_group(
+            [{"CPU": rl_config.num_cpus_for_local_task} for _ in range(rl_config.rule_reward_num_process)],  # resource demand of each bundle
+            strategy='SPREAD'
+        )
+        ray.get(pg.ready())
+        for i in range(rl_config.rule_reward_num_process):
+            rule_reward = RuleReward.options(placement_group=pg, placement_group_bundle_index=i).remote()
+            rule_reward.initialize.remote(reward_config, rl_config, tokenizer)
+            reward_list.append(rule_reward)
 
     train_ds, _, _ = build_train_valid_test_datasets(
         data_prefix=[actor_config.data_path, ],
diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py
index de2ae915..8f2900dc 100644
--- a/mindspeed_rl/config_cls/rl_config.py
+++ b/mindspeed_rl/config_cls/rl_config.py
@@ -47,6 +47,7 @@ class RLConfig(BaseConfig):
     def __init__(self, config_dict):
         self.runtime_env_path = 'configs/envs/runtime_env.yaml'
         self.rule_reward = True
+        self.rule_reward_num_process = 1
         self.beta = 0.1
         self.actor_resource = None
         self.reference_resource = None
diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py
index d2aef2b4..e01a304b 100644
--- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py
+++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py
@@ -126,11 +126,12 @@ class RayGRPOTrainer(RayBaseTrainer):
         self.actor_worker.generate_sequences(blocking=self.blocking)
 
         # compute rm scores.
+        rule_reward_compute = []
         for reward_worker in self.reward_list:
             if isinstance(reward_worker, RayActorGroup):
                 reward_worker.compute_rm_score(blocking=self.blocking)
             else:
-                self.rule_reward_compute_rm_score(reward_worker, blocking=False)
+                rule_reward_compute.append(reward_worker.compute_rm_score.remote())
 
         # compute advantages, executed on the driver process
         self.compute_advantage(blocking=False)
diff --git a/mindspeed_rl/utils/kwai.py b/mindspeed_rl/utils/kwai.py
index 89d6466b..307c11ea 100644
--- a/mindspeed_rl/utils/kwai.py
+++ b/mindspeed_rl/utils/kwai.py
@@ -12,7 +12,8 @@ from .evaluation_utils.math_util import evaluate_math
 import time
 import ray
 from tqdm.asyncio import tqdm
-
+import multiprocessing as mp
+from multiprocessing import Process, Queue
 
 
 def validate_response_structure(processed_str: str, task: str) -> bool:
@@ -30,15 +31,17 @@ def validate_response_structure(processed_str: str, task: str) -> bool:
 
     return validation_passed
 
-
-def process_completion(completion, task, reference):
-    if task == "code":
-        return evaluate_code(completion, reference)
-    elif task == "math":
-        return evaluate_math(completion, str(reference))
-    else:
-        print('task')
-        raise NotImplementedError
+def process_completion(q, completions, tasks, references):
+    """Score a batch of (completion, task, reference) triples and put the results on q."""
+    result = []
+    for completion, task, reference in zip(completions, tasks, references):
+        if task == "code":
+            result.append([evaluate_code(completion, reference)])
+        elif task == "math":
+            result.append([evaluate_math(completion, str(reference))])
+        else:
+            raise NotImplementedError(f"unsupported task type: {task}")
+    q.put(result)
 
 
 def get_format_score(validation_passed):
@@ -149,28 +152,68 @@ async def parallel_evaluate_continual_async(completions, references, tasks, num_
         scores.append(total_score)
     return scores
 
-@ray.remote
-def process_row(completion, reference, task, timeout=300.0):
-    """
-    Process a single row synchronously.
-    """
-    try:
-        result = process_completion(completion, task, reference)
-        return [result]
-    except Exception as e:
-        print(f"Error processing completion in Process_Row function: {completion[:10]}, Error: {e}")
-        return None
+def process_row(completions, references, tasks, max_num_workers=8):
+    """
+    Score a batch of rows with a pool of local worker processes.
+    """
+    # Allow roughly 100 seconds per completion per worker before giving up.
+    timeout = len(completions) / max_num_workers * 100
+    return multiprocess_executor(process_completion, completions, tasks, references,
+                                 timeout_seconds=timeout, max_num_workers=max_num_workers)
+
+
+def multiprocess_executor(worker, completions, tasks, references, timeout_seconds=30, max_num_workers=8):
+    if not completions:
+        return []
+
+    # Scale the number of processes to the data volume, so that every
+    # process receives at least one task.
+    num_workers = max(1, min(len(completions), mp.cpu_count() - 1, max_num_workers))
+    batch_size = len(completions) // num_workers
+
+    processes = []
+    lengths = []
+    queues = []  # one queue per process, so results are collected in order
+
+    for i in range(num_workers):
+        start_index = i * batch_size
+        end_index = (i + 1) * batch_size if i < num_workers - 1 else len(completions)
+        lengths.append(end_index - start_index)
+        q = Queue()
+        queues.append(q)
+        p = Process(target=worker, args=(q,
+                                         completions[start_index:end_index],
+                                         tasks[start_index:end_index],
+                                         references[start_index:end_index]))
+        processes.append(p)
+        p.start()
+
+    final_results = []
+    for i, p in enumerate(processes):
+        p.join(timeout=timeout_seconds)
+        if p.is_alive():
+            # The batch hung: kill the process and fill in failure placeholders.
+            p.terminate()
+            print(f"worker {i} timed out; filling {lengths[i]} placeholder results")
+            for _ in range(lengths[i]):
+                final_results.append([(False, 'No answer found')])
+        else:
+            try:
+                # Fetch this worker's results from its own queue.
+                res = queues[i].get_nowait()
+                final_results.extend(res)
+            except Exception:
+                for _ in range(lengths[i]):
+                    final_results.append([(False, 'No answer found')])
+    return final_results
 
 
 def sequential_evaluate_continual(completions, references, tasks):
     """
-    Evaluate rows sequentially without concurrency.
+    Evaluate all rows in one batched call, fanning the work out to worker processes.
     """
     scores = []
-    futures = []
-    for completion, reference, task in zip(completions, references, tasks):
-        futures.append(process_row.remote(completion, reference, task ))
-    results = ray.get(futures)
+    results = process_row(completions, references, tasks)
 
     for completion, reference, task, result in zip(completions, references, tasks, results):
         validation_passed = validate_response_structure(completion, task)
@@ -206,8 +249,6 @@ def sequential_evaluate_continual(completions, references, tasks):
     return scores
 
 
-
-
 def compute_score(queue, sequences, answers, tasks, *kwargs):
     do_print = True
     # completions = [completions]
@@ -223,7 +264,6 @@ def compute_score(queue, sequences, answers, tasks, *kwargs):
     try:
         res = sequential_evaluate_continual(sequences, answers, tasks)
         print("res:", res[0])
-        # return res
     except asyncio.TimeoutError as e:
         print('Global timeout in reward computing! Setting all as 0.5.')
         res = [0.02 for _ in range(len(sequences))]
-- 
Gitee