From 0c8a2b09878d99cc50730084f351e14aa400c9d5 Mon Sep 17 00:00:00 2001 From: linqihong Date: Thu, 10 Apr 2025 20:14:04 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=8B=92=E7=BB=9D?= =?UTF-8?q?=E9=87=87=E6=A0=B7=E7=AE=97=E6=B3=95=E5=8F=8A=E5=85=B6ut?= =?UTF-8?q?=E3=80=81st?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cli/rejection_sampling.py | 189 ++++++++++++++++++ .../config_cls/rejection_sampling_config.py | 82 ++++++++ mindspeed_rl/datasets/infer_dataset.py | 130 ++++++++++++ mindspeed_rl/models/vllm_gen_rm.py | 146 ++++++++++++++ mindspeed_rl/models/vllm_infer.py | 94 +++++++++ mindspeed_rl/utils/rejection_sampler.py | 61 ++++++ requirements.txt | 31 +-- tests/st/rejection_sampling/rj_qwen25_7b.yaml | 25 +++ .../test_module_entry_rejection_sampling.sh | 9 + tests/st/rejection_sampling/test_rj.py | 123 ++++++++++++ .../datasets/test_prompt_gt_answer_dataset.py | 54 +++++ tests/ut/utils/test_rejection_sampling.py | 69 +++++++ 12 files changed, 998 insertions(+), 15 deletions(-) create mode 100644 cli/rejection_sampling.py create mode 100644 mindspeed_rl/config_cls/rejection_sampling_config.py create mode 100644 mindspeed_rl/datasets/infer_dataset.py create mode 100644 mindspeed_rl/models/vllm_gen_rm.py create mode 100644 mindspeed_rl/models/vllm_infer.py create mode 100644 mindspeed_rl/utils/rejection_sampler.py create mode 100644 tests/st/rejection_sampling/rj_qwen25_7b.yaml create mode 100644 tests/st/rejection_sampling/test_module_entry_rejection_sampling.sh create mode 100644 tests/st/rejection_sampling/test_rj.py create mode 100644 tests/ut/datasets/test_prompt_gt_answer_dataset.py create mode 100644 tests/ut/utils/test_rejection_sampling.py diff --git a/cli/rejection_sampling.py b/cli/rejection_sampling.py new file mode 100644 index 0000000..79400b9 --- /dev/null +++ b/cli/rejection_sampling.py @@ -0,0 +1,189 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. 
+import gc +from typing import Dict, Any + +import hydra +import math +import ray +import jsonlines +from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel + +if not ray.is_initialized(): + ray.init() + +from mindspeed_rl import get_tokenizer +from mindspeed_rl.models.vllm_infer import RayInferActor +from mindspeed_rl.utils.loggers import Loggers +from mindspeed_rl.config_cls.rejection_sampling_config import RejectionSamplingConfig +from mindspeed_rl.datasets.infer_dataset import load_datasets, PromptGtAnswerDataset +from mindspeed_rl.models.vllm_gen_rm import RayGenRM +from mindspeed_rl.utils.rejection_sampler import filtering_by_rules, rejection_sampling_processor + +logger = Loggers('rejection_sampling') + + +def clean_up(): + destroy_model_parallel() + destroy_distributed_environment() + gc.collect() + + +def parse_config(config: Dict): + """ + 解析配置,提取拒绝采样配置、vllm推理配置。 + + config: 输入的全局配置字典。 + return: 拒绝采样配置、vllm推理配置。 + """ + rj_config = RejectionSamplingConfig(config.get("rj_config")) + infer_config = rj_config.sampling_config + return rj_config, infer_config + + +def batch_generate_vllm(rj_config: RejectionSamplingConfig, infer_config: Dict[str, Any]) -> None: + """ + 用ray启动多个vllm实例进行批量推理,保存prompt、response、gt_answer + """ + tokenizer = get_tokenizer(rj_config.model_path) + + prompts_data = load_datasets( + rj_config.dataset_path, + rj_config.probabilities, + seed=infer_config.seed + ) + + if rj_config.iter is None or rj_config.iter < 0: + prompts_data = prompts_data.select(range(min(rj_config.max_samples, len(prompts_data)))) + else: + start_idx = rj_config.iter * rj_config.rollout_batch_size + end_idx = start_idx + rj_config.rollout_batch_size + prompts_data = prompts_data.select(range(start_idx, min(end_idx, len(prompts_data)))) + + data = PromptGtAnswerDataset(prompts_data, tokenizer, rj_config.map_keys, + rj_config.apply_chat_template, input_template=rj_config.input_template) + data = list(data) + batches = [] + for i in range(0, len(data), math.ceil(len(data) / rj_config.num_vllm_instances)): + batches.append(data[i:i + math.ceil(len(data) / rj_config.num_vllm_instances)]) + + prompts = [[item["prompt"] for item in batch] for batch in batches] + gt_answers = [[item["gt_answer"] for item in batch] for batch in batches] + + vllm_instances = [] + for _ in range(rj_config.num_vllm_instances): + vllm_instances.append( + RayInferActor.options(resources={"NPU": infer_config.infer_tensor_parallel_size}) + .remote( + model_path=rj_config.model_path, + infer_tensor_parallel_size=infer_config.infer_tensor_parallel_size, + seed=infer_config.seed, + max_num_seqs=infer_config.max_num_seqs, + max_new_tokens=infer_config.max_new_tokens, + top_p=infer_config.top_p, + top_k=infer_config.top_k, + temperature=infer_config.temperature, + repetition_penalty=infer_config.repetition_penalty, + ) + ) + + for vllm_ins in vllm_instances: + vllm_ins.init_vLLM.remote() + + outputs = ray.get([vllm_ins.generate_sequences.remote(prompts[i] * rj_config.best_of_n) for i, vllm_ins in + enumerate(vllm_instances)]) + + output_dataset = [] + for i in range(rj_config.num_vllm_instances): + for output, gt_answer in zip(outputs[i], gt_answers[i] * rj_config.best_of_n): + prompt = output.prompt + output = output.outputs[0].text + output_dataset.append({"prompt": prompt, "response": output, "gt_answer": gt_answer}) + + with jsonlines.open(rj_config.output_path, mode="w") as writer: + writer.write_all(output_dataset) + + del vllm_instances + + +def 
batch_rm_rejection_sampling(rj_config: RejectionSamplingConfig, infer_config: Dict[str, Any]) -> None: + """ + Launch multiple generative reward model instances with Ray to run rejection sampling. + """ + data = load_datasets( + rj_config.dataset_path, + probabilities=None, + seed=infer_config.seed, + max_count=rj_config.max_samples + ) + + input_key = rj_config.map_keys.get("prompt") + gt_answer_key = rj_config.map_keys.get("gt_answer") + response_key = rj_config.map_keys.get("response") + if rj_config.filter_by_rules: + data = filtering_by_rules(data, gt_answer_key, response_key, + rj_config.verifier_parallel, + rj_config.verifier_function, + rj_config.verifier_weight, + rj_config.verifier_timeout, + rj_config.accept_score) + + batches = [] + for i in range(0, len(data), math.ceil(len(data) / rj_config.num_vllm_instances)): + batches.append(data[i:i + math.ceil(len(data) / rj_config.num_vllm_instances)]) + + rm_instances = [] + for _ in range(rj_config.num_vllm_instances): + rm_instances.append( + RayGenRM.options(resources={"NPU": infer_config.infer_tensor_parallel_size}) + .remote( + model_path=rj_config.model_path, + infer_tensor_parallel_size=infer_config.infer_tensor_parallel_size, + seed=infer_config.seed, + max_num_seqs=infer_config.max_num_seqs, + max_new_tokens=infer_config.max_new_tokens, + top_p=infer_config.top_p, + top_k=infer_config.top_k, + temperature=infer_config.temperature, + repetition_penalty=infer_config.repetition_penalty, + map_keys=rj_config.map_keys, + use_ground_truth_answer=rj_config.use_ground_truth_answer, + ) + ) + + for rm_ins in rm_instances: + rm_ins.init_vLLM.remote() + + outputs = ray.get([rm_ins.run_rm_judge.remote(batches[i]) for i, rm_ins in enumerate(rm_instances)]) + + outputs = [item for output in outputs for item in output] + if rj_config.save_rm_judgement: + with jsonlines.open(rj_config.rm_judgement_path, mode="w") as writer: + writer.write_all(outputs) + + res = rejection_sampling_processor(outputs, input_key, response_key) + + with jsonlines.open(rj_config.output_path, mode="w") as writer: + writer.write_all(res) + + print(f"Processing complete and data saved to '{rj_config.output_path}'.") + + del rm_instances + + +@hydra.main(config_path='../configs', config_name='batch_generate_vllm_qwen25_7b', version_base=None) +def main(config): + rj_config, infer_config = parse_config(config) + + if rj_config.task == "generate_vllm": + batch_generate_vllm(rj_config, infer_config) + elif rj_config.task == "rejection_sampling": + batch_rm_rejection_sampling(rj_config, infer_config) + else: + logger.info( + "Invalid or missing 'task' in the config. Please specify either 'generate_vllm' or 'rejection_sampling'.") + + clean_up() + + +if __name__ == '__main__': + main() diff --git a/mindspeed_rl/config_cls/rejection_sampling_config.py b/mindspeed_rl/config_cls/rejection_sampling_config.py new file mode 100644 index 0000000..8e01643 --- /dev/null +++ b/mindspeed_rl/config_cls/rejection_sampling_config.py @@ -0,0 +1,82 @@ +import json +from mindspeed_rl.config_cls.base_config import BaseConfig + + +class RejectionSamplingConfig(BaseConfig): + ''' + Initialize model configuration from the provided config dictionary. + All instance attributes are initialized using the dictionary keys.
+ + task: The task to run, either generate_vllm or rejection_sampling (default: None) + model_path: Path of the model (default: "") + dataset_path: Path of the dataset (default: "") + output_path: Path to save the sampled results (default: "") + save_rm_judgement: Enable saving reward model judgements (default: False) + rm_judgement_path: Path to save reward model judgements (default: "") + map_keys: Dataset keys mapping (default: '{"prompt":"input","gt_answer":"gt_answer","response":""}') + probabilities: Mixing probabilities of multiple datasets (default: None) + input_template: User-defined template (default: None) + apply_chat_template: Whether to use the chat template provided by the model. (default: False) + + iter: Data slice index; selects samples in [iter * rollout_batch_size, (iter + 1) * rollout_batch_size) (default: 0) + rollout_batch_size: Number of prompts per rollout batch (default: 100) + best_of_n: Number of responses generated for each prompt (default: 8) + max_samples: Maximum number of data samples; used when iter is not set (default: 5000000) + use_ground_truth_answer: Whether to use the ground truth answer from the dataset. (default: False) + num_vllm_instances: The number of running vLLM instances. (default: 1) + infer_tensor_parallel_size: Tensor parallel size during evaluation. (default: 1) + + filter_by_rules: Whether to filter the dataset with rule-based verifiers. (default: False) + accept_score: Threshold on the rule score used when filtering data. (default: 0.5) + + max_new_tokens: Maximum number of tokens to generate. (default: 1024) + max_num_seqs: Maximum number of sequences vLLM processes simultaneously. (default: 64) + top_p: The cumulative probability threshold for nucleus sampling. (default: 0.8) + top_k: The number of highest-probability tokens to consider for sampling. (default: 5) + temperature: Controls the randomness of predictions by scaling the logits before applying softmax. (default: 0.5) + repetition_penalty: Penalty applied to repeated content; a value of 1.0 means no penalty. (default: 1.0) + seed: Random seed.
(default: 1234) + ''' + + def __init__(self, config_dict): + self.task = None + self.model_path = "" + self.dataset_path = "" + self.output_path = "" + self.rm_judgement_path = "" + self.save_rm_judgement = False + self.map_keys = '{"prompt":"input","gt_answer":"gt_answer","response":""}' + self.probabilities = None + self.input_template = None + self.apply_chat_template = False + + self.iter = 0 + self.rollout_batch_size = 100 + self.best_of_n = 8 + self.max_samples = 5000000 + self.use_ground_truth_answer = False + + self.infer_tensor_parallel_size = 1 + self.num_vllm_instances = 1 + + self.filter_by_rules = False + self.verifier_function = ["acc", ] + self.verifier_weight = [1.0, ] + self.verifier_parallel = 4 + self.verifier_timeout = 120 + self.accept_score = 0.5 + + self.sampling_config = { + "infer_tensor_parallel_size": 1, # tensor并行数 + "max_new_tokens": 1024, # 生成输出的最大 token 数量 + "max_num_seqs": 64, # 同时处理的最大序列数 + "top_p": 0.8, # 核采样的累积概率阈值 + "top_k": 5, # 采样时考虑的最高概率 token 的数量 + "temperature": 0.5, # 控制预测随机性的温度参数 + "repetition_penalty": 1.0, # 重复惩罚系数 + "seed": 1234 # 随机数种子 + } + + config_dict['map_keys'] = json.loads(config_dict['map_keys'] if 'map_keys' in config_dict.keys() else + self.map_keys) + self.update(config_dict) diff --git a/mindspeed_rl/datasets/infer_dataset.py b/mindspeed_rl/datasets/infer_dataset.py new file mode 100644 index 0000000..dd1d375 --- /dev/null +++ b/mindspeed_rl/datasets/infer_dataset.py @@ -0,0 +1,130 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. +# Copyright (c) 2023-2025 The OpenRLHF Authors. + +import os +from typing import Dict, Any, Callable + +from datasets import interleave_datasets, load_dataset, load_from_disk +from datasets.combine import DatasetType +from torch.utils.data import Dataset +from tqdm import tqdm + + +def preprocess_data(data: Dict[str, Any], input_template: str = None, input_key: str = "input", + apply_chat_template: Callable = None) -> str: + if apply_chat_template: + chat = data[input_key] + if isinstance(chat, str): + chat = [{"role": "user", "content": chat}] + prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + else: + prompt = data[input_key] if input_key in data.keys() else None + if input_template: + prompt = input_template.format(prompt) + return prompt + + +def load_single_dataset(dataset: str, data_dir: str = None): + """ + Load a single dataset given its path or identifier. + Supports loading from different sources: local files, remote datasets, or directories. 
+ """ + dataset = dataset.strip() + ext = os.path.splitext(dataset)[-1] + + # local python script + if ext == ".py" or ( + os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{os.path.basename(dataset)}.py"))): + data = load_dataset(dataset, trust_remote_code=True) + # local text file + elif ext in [".json", ".jsonl", ".csv", ".parquet"]: + ext = ext.lower().strip(".") + if ext == "jsonl": + ext = "json" + data = load_dataset(ext, data_files=dataset) + # remote/local folder or common file + else: + data = load_dataset(dataset, data_dir=data_dir) + return data + + +def load_datasets( + datasets_path: str, + probabilities: str = None, + seed: int = 42, + max_count: int = -1, + train_split: str = "train" +) -> DatasetType: + datasets_path = datasets_path.split(",") + if probabilities is None: + probabilities = [1.0] * len(datasets_path) + else: + probabilities = list(map(float, probabilities.split(","))) + assert len(probabilities) == len(datasets_path) + + train_data_list = [] + + for i, dataset in enumerate(datasets_path): + print(f"dataset: {dataset}") + + data_dir = dataset.split("@")[1].strip() if "@" in dataset else "" + dataset_name = dataset.split("@")[0].strip() + + # Call load_single_dataset to load the dataset + data = load_single_dataset(dataset_name, data_dir) + + # Select train data + if train_split and train_split in data: + max_count = min(max_count, len(data[train_split])) if max_count > 0 else len(data[train_split]) + train_data = data[train_split].select(range(max_count)) + else: + max_count = min(max_count, len(data)) if max_count > 0 else len(data) + train_data = data.select(range(max_count)) + train_data_list.append(train_data) + + # Merge datasets + train_dataset = interleave_datasets( + train_data_list, + probabilities=probabilities, + seed=seed, + ) + + return train_dataset + + +class PromptGtAnswerDataset(Dataset): + ''' + Base configuration class. + ''' + def __init__( + self, + dataset: DatasetType, + tokenizer: Callable, + map_keys: Dict[str, str], + apply_chat_template: bool = False, + input_template: str = None, + ) -> None: + super().__init__() + self.tokenizer = tokenizer + + self.input_template = input_template + input_key = map_keys.get("prompt", None) + gt_answer_key = map_keys.get("gt_answer", None) + + if apply_chat_template: + apply_chat_template = self.tokenizer.apply_chat_template + + self.prompts = [] + self.gt_answers = [] + for data in tqdm(dataset, desc="Preprocessing data"): + prompt = preprocess_data(data, input_template, input_key, apply_chat_template) + gt_answer = preprocess_data(data, input_key=gt_answer_key) + self.prompts.append(prompt) + self.gt_answers.append(gt_answer) + + def __len__(self): + length = len(self.prompts) + return length + + def __getitem__(self, idx): + return {"prompt": self.prompts[idx], "gt_answer": self.gt_answers[idx]} diff --git a/mindspeed_rl/models/vllm_gen_rm.py b/mindspeed_rl/models/vllm_gen_rm.py new file mode 100644 index 0000000..b2b1dfb --- /dev/null +++ b/mindspeed_rl/models/vllm_gen_rm.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. +import re +from typing import List, Dict, Any +import ray + +from mindspeed_rl.models.vllm_infer import InferActor + + +class GenRM(InferActor): + """ + Generative reward model using VLLM for inference. + + Args: + model_path: Path of the model. + infer_tensor_parallel_size: Tensor parallel size during evaluation. (default: 1) + seed: Random seeds. 
(default: 1234) + max_num_seqs: Maximum number of sequences to process simultaneously of vLLM definition. (default: 64) + max_new_tokens: Length of the output generated text. (default: 2048) + top_p: The cumulative probability threshold for nucleus sampling. (default: 0.8) + top_k: The number of highest - probability tokens to consider for sampling. (default: 10) + temperature: Controls the randomness of predictions by scaling the logits before applying softmax. (default: 1.0) + repetition_penalty: Control the ratio of repeat content, it will penalize repeating generation. Default is 1 represents no penalty. + map_keys: Dataset keys mapping, (default: None) + Whether to use the ground truth answer from the dataset. (default: False) + """ + def __init__(self, + model_path: str, + infer_tensor_parallel_size: int = 1, + seed: int = 1234, + max_num_seqs: int = 64, + max_new_tokens: int = 2048, + top_p: float = 0.8, + top_k: int = 10, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + map_keys: Dict[str, str] = None, + use_ground_truth_answer: bool = False, + ): + super(GenRM, self).__init__(model_path, + infer_tensor_parallel_size, + seed, + max_num_seqs, + max_new_tokens, + top_p, + top_k, + temperature, + repetition_penalty) + self.input_key = map_keys.get("prompt") + self.response_key = map_keys.get("response") + self.gt_answer_key = map_keys.get("gt_answer") + self.use_ground_truth_answer = use_ground_truth_answer + + def process_row_function(self, example): + """ + 处理每行数据,包括提取需要字段,加载生成式奖励模型的评估模版 + """ + prompt_text = example[self.input_key] + response_text = example[self.response_key] + if self.use_ground_truth_answer: + gt_answer_text = example[self.gt_answer_key] + else: + gt_answer_text = None + + judgement_prompt = self.apply_gen_rm_template(prompt_text, response_text, gt_answer_text) + example['judgement_prompt'] = judgement_prompt + return example + + def get_gen_rm_rewards(self, input_data: List[Dict[str, Any]]): + """ + 对回复进行生成式奖励模型评判,解析评判文本,得到reward分数 + """ + judgement_prompts = [item['judgement_prompt'] for item in input_data] + judgements = self.generate_sequences(judgement_prompts) + + output_dataset = [] + for example, judgement in zip(input_data, judgements): + example["judgement"] = judgement.outputs[0].text + judgement_parsing = re.findall(r'(\s*-?\d+(?:\.\d+)?\s*)', example["judgement"]) + if judgement_parsing: + example["reward"] = float(judgement_parsing[0]) + else: + example["reward"] = -1 + output_dataset.append(example) + return output_dataset + + def apply_gen_rm_template(self, prompt_text: str, response_text: str, ground_truth_answer_text: str) -> str: + """ + 加载生成式奖励模型的评估模版,包括有无正确答案的两套模板 + """ + if ground_truth_answer_text: + full_input = f"""<|im_start|>你是一个判别推理正确性的专家。 + 问题[PROMPT]: {prompt_text} + 正确答案[GROUND TRUTH]: {ground_truth_answer_text} + 回复[RESPONSE]: {response_text} + 任务目标:根据给定的问题[PROMPT]和正确答案[GROUND TRUTH],评估回复[RESPONSE]质量。重点考虑回复结果的正确性,\ + 其次考虑语言一致性、格式正确性、推理合理性、语句重复冗余性,用简洁的文字说明原因。\ + 最后给出0到1之间的分数,分数以score here的形式给出。<|im_end|> + <|im_start|>assistant + """ + else: + full_input = f"""<|im_start|>你是一个判别推理正确性的专家。 + 问题[PROMPT]: {prompt_text} + 回复[RESPONSE]: {response_text} + 任务目标:根据给定的问题[PROMPT],评估回复[RESPONSE]质量,考虑回复结果的正确性、语言一致性、\ + 格式正确性、推理合理性、回复无害性、语句重复冗余性,用简洁的文字说明原因。\ + 最后给出0到1之间的分数,分数以score here的形式给出。<|im_end|> + <|im_start|>assistant + """ + return full_input + + def run_rm_judge(self, input_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 运行生成式奖励模型评判函数 + """ + data = [] + for item in input_data: + 
data.append(self.process_row_function(item)) + res = self.get_gen_rm_rewards(data) + return res + + +@ray.remote +class RayGenRM(GenRM): + def __init__(self, + model_path: str, + infer_tensor_parallel_size: int = 1, + seed: int = 1234, + max_num_seqs: int = 64, + max_new_tokens: int = 2048, + top_p: float = 0.8, + top_k: int = 10, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + map_keys: Dict[str, str] = None, + use_ground_truth_answer: bool = False): + super().__init__(model_path, + infer_tensor_parallel_size, + seed, + max_num_seqs, + max_new_tokens, + top_p, + top_k, + temperature, + repetition_penalty, + map_keys, + use_ground_truth_answer) diff --git a/mindspeed_rl/models/vllm_infer.py b/mindspeed_rl/models/vllm_infer.py new file mode 100644 index 0000000..76d371a --- /dev/null +++ b/mindspeed_rl/models/vllm_infer.py @@ -0,0 +1,94 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. +from typing import List + +import ray +from vllm import LLM, SamplingParams + + +class InferActor(): + """ + Infer Actor using VLLM for inference. + + Args: + model_path: Path of the model. + infer_tensor_parallel_size: Tensor parallel size during evaluation. (default: 1) + seed: Random seeds. (default: 1234) + max_num_seqs: Maximum number of sequences to process simultaneously of vLLM definition. (default: 64) + max_new_tokens: Length of the output generated text. (default: 2048) + top_p: The cumulative probability threshold for nucleus sampling. (default: 0.8) + top_k: The number of highest - probability tokens to consider for sampling. (default: 10) + temperature: Controls the randomness of predictions by scaling the logits before applying softmax. (default: 1.0) + repetition_penalty: Control the ratio of repeat content, it will penalize repeating generation. Default is 1 represents no penalty. + """ + def __init__(self, + model_path: str, + infer_tensor_parallel_size: int = 1, + seed: int = 1234, + max_num_seqs: int = 64, + max_new_tokens: int = 2048, + top_p: float = 0.8, + top_k: int = 10, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + ): + self.llm = None + self.sampling_params = None + self.model_path = model_path + self.infer_tensor_parallel_size = infer_tensor_parallel_size + self.seed = seed + self.max_num_seqs = max_num_seqs + self.max_new_tokens = max_new_tokens + self.top_p = top_p + self.top_k = top_k + self.temperature = temperature + self.repetition_penalty = repetition_penalty + + def init_vLLM(self): + """ + 初始化离线推理接口 vLLM.LLM() + """ + self.llm = LLM( + model=self.model_path, + tensor_parallel_size=self.infer_tensor_parallel_size, + trust_remote_code=True, + seed=self.seed, + max_num_seqs=self.max_num_seqs + ) + + # Create a sampling params object. 
+ self.sampling_params = SamplingParams( + max_tokens=self.max_new_tokens, + top_p=self.top_p, + top_k=self.top_k, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty + ) + + def generate_sequences(self, prompts: List[str]): + """generate sequences via vLLM offline inference""" + outputs = self.llm.generate(prompts, self.sampling_params) + + return outputs + + +@ray.remote +class RayInferActor(InferActor): + def __init__(self, + model_path: str, + infer_tensor_parallel_size: int = 1, + seed: int = 1234, + max_num_seqs: int = 64, + max_new_tokens: int = 2048, + top_p: float = 0.8, + top_k: int = 10, + temperature: float = 1.0, + repetition_penalty: float = 1.0,): + super().__init__(model_path, + infer_tensor_parallel_size, + seed, + max_num_seqs, + max_new_tokens, + top_p, + top_k, + temperature, + repetition_penalty) diff --git a/mindspeed_rl/utils/rejection_sampler.py b/mindspeed_rl/utils/rejection_sampler.py new file mode 100644 index 0000000..998057a --- /dev/null +++ b/mindspeed_rl/utils/rejection_sampler.py @@ -0,0 +1,61 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. +from typing import List, Dict, Any +from tqdm import tqdm + + +def filtering_by_rules(input_data: List[Dict[str, Any]], + gt_answer_key: str, + response_key: str, + verifier_parallel: int = 4, + verifier_function: List[str] = ["acc"], + verifier_weight: List[float] = [1.0], + verifier_timeout: int = 300, + accept_score: float = 0.5 + ): + """ + Filter samples with rule-based verifiers; keep only examples whose rule score is at least accept_score. + """ + from types import SimpleNamespace + from mindspeed_rl.models.rule_verifier import verifier + labels = {} + labels["labels"] = [data[gt_answer_key] for data in input_data] + responses = [data[response_key] for data in input_data] + + config = SimpleNamespace() + config.verifier_parallel = verifier_parallel + config.verifier_function = verifier_function + config.verifier_weight = verifier_weight + config.verifier_timeout = verifier_timeout + + scores, reward = verifier(responses, labels, config) + + output_dataset = [] + for example, score in zip(input_data, scores): + example["rule_score"] = score + if score >= accept_score: + example['select'] = True + else: + example['select'] = False + output_dataset.append(example) + output_dataset = list(filter(lambda item: item["select"], output_dataset)) + return output_dataset + + +def rejection_sampling_processor(data: List[Dict[str, Any]], input_key: str, response_key: str): + """ + Rejection sampling: for each prompt, keep only the highest-reward response. + """ + out = {} + for item in tqdm(data, desc="Rejection Sampling process...."): + prompt = item[input_key] + response = item[response_key] + reward = item["reward"] + + if reward > 0: + if prompt not in out: + out[prompt] = {"response": response, "reward": reward} + elif reward > out[prompt]["reward"]: + out[prompt]["reward"] = reward + out[prompt]["response"] = response + + return [{"prompt": k, "response": v["response"], "reward": v["reward"]} for k, v in out.items()] diff --git a/requirements.txt b/requirements.txt index 76f1056..556ddae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,16 @@ -codetiming -datasets -hydra-core==1.3.2 -latex2sympy2 -numpy==1.26.4 -omegaconf -pybind11 -ray==2.42.1 -regex -sympy -tensorboard -tensordict -transformers==4.48.2 -word2number -wandb +codetiming +datasets +hydra-core==1.3.2 +jsonlines +latex2sympy2 +numpy==1.26.4 +omegaconf +pybind11 +ray==2.42.1 +regex +sympy +tensorboard +tensordict +transformers==4.48.2 +word2number +wandb diff --git a/tests/st/rejection_sampling/rj_qwen25_7b.yaml
b/tests/st/rejection_sampling/rj_qwen25_7b.yaml new file mode 100644 index 0000000..ef94296 --- /dev/null +++ b/tests/st/rejection_sampling/rj_qwen25_7b.yaml @@ -0,0 +1,25 @@ +rj_config: + model_path: /home/cq/code/MindSpeed-LLM-lxy/MindSpeed-LLM-dl/model_from_hf/Qwen2.5-7B-Instruct + dataset_path: /data/for_dt/datasets/rejection_sampling/generate_output_math.jsonl + task: rejection_sampling + map_keys: '{"prompt":"prompt", "gt_answer":"gt_answer", "response":"response"}' + use_ground_truth_answer: True + num_vllm_instances: 2 + filter_by_rules: true + verifier_function: ["acc", ] + verifier_weight: [1.0, ] + verifier_parallel: 4 + verifier_timeout: 120 + accept_score: -1.0 + + sampling_config: + infer_tensor_parallel_size: 2 + max_new_tokens: 512 + max_num_seqs: 64 + temperature: 0.3 + repetition_penalty: 1.0 + top_p: 0.3 + top_k: 3 + seed: 1234 + +#/data/for_dt/weights/Qwen2.5-7B-Instruct: \ No newline at end of file diff --git a/tests/st/rejection_sampling/test_module_entry_rejection_sampling.sh b/tests/st/rejection_sampling/test_module_entry_rejection_sampling.sh new file mode 100644 index 0000000..8f1e00e --- /dev/null +++ b/tests/st/rejection_sampling/test_module_entry_rejection_sampling.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# 获取脚本的绝对路径 +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +export PYTHONPATH=$SCRIPT_DIR/../../..:$PYTHONPATH + +ray stop + +python $SCRIPT_DIR/test_rj.py --config-name rj_qwen25_7b.yaml \ No newline at end of file diff --git a/tests/st/rejection_sampling/test_rj.py b/tests/st/rejection_sampling/test_rj.py new file mode 100644 index 0000000..2ed5451 --- /dev/null +++ b/tests/st/rejection_sampling/test_rj.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. +import gc +import time +from typing import Dict, Any + +import hydra +import math +import ray +from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel + +if not ray.is_initialized(): + ray.init() + +from mindspeed_rl.datasets.infer_dataset import load_datasets +from mindspeed_rl.config_cls.rejection_sampling_config import RejectionSamplingConfig +from mindspeed_rl.models.vllm_gen_rm import RayGenRM +from mindspeed_rl.utils.loggers import Loggers +from mindspeed_rl.utils.rejection_sampler import rejection_sampling_processor, filtering_by_rules + +logger = Loggers('rejection_sampling') + + +def clean_up(): + destroy_model_parallel() + destroy_distributed_environment() + gc.collect() + + +def parse_config(config: Dict): + """ + 解析配置,提取拒绝采样配置、vllm推理配置。 + + config: 输入的全局配置字典。 + return: 拒绝采样配置、vllm推理配置。 + """ + rj_config = RejectionSamplingConfig(config.get("rj_config")) + infer_config = rj_config.sampling_config + return rj_config, infer_config + + +def batch_rm_rejection_sampling(rj_config: RejectionSamplingConfig, infer_config: Dict[str, Any]) -> None: + """ + 用ray启动多个生成式奖励模型实例进行拒绝采样 + """ + data = load_datasets( + rj_config.dataset_path, + probabilities=None, + seed=infer_config.seed, + max_count=rj_config.max_samples + ) + + input_key = rj_config.map_keys.get("prompt") + gt_answer_key = rj_config.map_keys.get("gt_answer") + response_key = rj_config.map_keys.get("response") + if rj_config.filter_by_rules: + data = filtering_by_rules(data, gt_answer_key, response_key, + rj_config.verifier_parallel, + rj_config.verifier_function, + rj_config.verifier_weight, + rj_config.verifier_timeout, + rj_config.accept_score) + + batches = [] + for i in range(0, len(data), math.ceil(len(data) / rj_config.num_vllm_instances)): + batches.append(data[i:i + 
math.ceil(len(data) / rj_config.num_vllm_instances)]) + + rm_instances = [] + for _ in range(rj_config.num_vllm_instances): + rm_instances.append( + RayGenRM.options(resources={"NPU": infer_config.infer_tensor_parallel_size}) + .remote( + model_path=rj_config.model_path, + infer_tensor_parallel_size=infer_config.infer_tensor_parallel_size, + seed=infer_config.seed, + max_num_seqs=infer_config.max_num_seqs, + max_new_tokens=infer_config.max_new_tokens, + top_p=infer_config.top_p, + top_k=infer_config.top_k, + temperature=infer_config.temperature, + repetition_penalty=infer_config.repetition_penalty, + map_keys=rj_config.map_keys, + use_ground_truth_answer=rj_config.use_ground_truth_answer, + ) + ) + + for rm_ins in rm_instances: + rm_ins.init_vLLM.remote() + + outputs = ray.get([rm_ins.run_rm_judge.remote(batches[i]) for i, rm_ins in enumerate(rm_instances)]) + outputs = [item for output in outputs for item in output] + + res = rejection_sampling_processor(outputs, input_key, response_key) + + del rm_instances + + return res + + +@hydra.main(config_path='./', config_name='rj_qwen25_7b', version_base=None) +def main(config): + rj_config, infer_config = parse_config(config) + a = time.time() + + if rj_config.task == "rejection_sampling": + data_path = None + if 'rejection_sampling' in rj_config.dataset_path: + data_path = rj_config.dataset_path + + try: + if data_path: + result = batch_rm_rejection_sampling(rj_config, infer_config) + logger.info('\n rejection sampling result length: {}'.format(len(result))) + logger.info('\n rejection sampling result: {}'.format(result)) + except Exception as e: + logger.info(e) + + logger.info(f'Rejection Sampling Running Time:, {time.time() - a}') + + clean_up() + + +if __name__ == '__main__': + main() diff --git a/tests/ut/datasets/test_prompt_gt_answer_dataset.py b/tests/ut/datasets/test_prompt_gt_answer_dataset.py new file mode 100644 index 0000000..f916c36 --- /dev/null +++ b/tests/ut/datasets/test_prompt_gt_answer_dataset.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. + +from datasets import Dataset + +from mindspeed_rl.datasets.infer_dataset import PromptGtAnswerDataset +from mindspeed_rl.utils.tokenizer import get_tokenizer + +from tests.test_tools.dist_test import DistributedTest + + +class TestPromptGtAnswerDataset(DistributedTest): + world_size = 1 + + def test_prompt_gt_answer_dataset(self): + tokenizer_directory = '/data/models/llama2-7b' + hf_tokenizer = get_tokenizer(tokenizer_directory) + map_keys = {"prompt": "input", "gt_answer": "gt_answer"} + input_data = Dataset.from_dict({"input": ["AAA", "BBB"], "gt_answer": ["A", "B"]}) + + dataset = PromptGtAnswerDataset( + dataset=input_data, + tokenizer=hf_tokenizer, + map_keys=map_keys, + apply_chat_template=False, + input_template=None, + ) + assert len(dataset) == 2, "The __len__ method of the PromptGtAnswerDataset failed!" + + assert dataset[0]["prompt"] == "AAA", "The __get_item__ method of the PromptGtAnswerDataset failed!" + assert dataset[0]["gt_answer"] == "A", "The __get_item__ method of the PromptGtAnswerDataset failed!" + assert dataset[1]["prompt"] == "BBB", "The __get_item__ method of the PromptGtAnswerDataset failed!" + assert dataset[1]["gt_answer"] == "B", "The __get_item__ method of the PromptGtAnswerDataset failed!" 
+ + def test_input_template(self): + tokenizer_directory = '/data/models/llama2-7b' + hf_tokenizer = get_tokenizer(tokenizer_directory) + map_keys = {"prompt": "input", "gt_answer": "gt_answer"} + input_data = Dataset.from_dict({"input": ["AAA", "BBB"], "gt_answer": ["A", "B"]}) + input_template = "<|im_start|>You are a helpful assitant.<|im_end|>User:{} Assitant:" + dataset = PromptGtAnswerDataset( + dataset=input_data, + tokenizer=hf_tokenizer, + map_keys=map_keys, + apply_chat_template=False, + input_template=input_template, + ) + assert len(dataset) == 2, "The __len__ method of the PromptGtAnswerDataset failed!" + assert dataset[0]["prompt"] == "<|im_start|>You are a helpful assitant.<|im_end|>User:AAA Assitant:", \ + "The __get_item__ method of the PromptGtAnswerDataset failed!" + assert dataset[0]["gt_answer"] == "A", "The __get_item__ method of the PromptGtAnswerDataset failed!" + assert dataset[1]["prompt"] == "<|im_start|>You are a helpful assitant.<|im_end|>User:BBB Assitant:", \ + "The __get_item__ method of the PromptGtAnswerDataset failed!" + assert dataset[1]["gt_answer"] == "B", "The __get_item__ method of the PromptGtAnswerDataset failed!" diff --git a/tests/ut/utils/test_rejection_sampling.py b/tests/ut/utils/test_rejection_sampling.py new file mode 100644 index 0000000..b5480e1 --- /dev/null +++ b/tests/ut/utils/test_rejection_sampling.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. + +from unittest.mock import patch +import pytest +from vllm import RequestOutput, CompletionOutput + +from mindspeed_rl.models.vllm_gen_rm import GenRM +from mindspeed_rl.models.vllm_infer import InferActor +from mindspeed_rl.utils.rejection_sampler import rejection_sampling_processor +from tests.test_tools.dist_test import DistributedTest + + +class TestRejectionSampling(DistributedTest): + + @pytest.fixture + def setUp(self): + map_keys = {"prompt": "prompt", "gt_answer": "gt_answer", "response": "response"} + self.rj = GenRM(model_path="", map_keys=map_keys) + + def test_initialization(self, setUp): + assert self.rj.input_key == "prompt" + assert self.rj.response_key == "response" + assert self.rj.gt_answer_key == "gt_answer" + + @patch.object(InferActor, "generate_sequences") + def test_get_gen_rm_rewards(self, mock_generate_sequences, setUp): + mock_generate_sequences.return_value = [ + RequestOutput(request_id="0", prompt=None, prompt_token_ids=[0], prompt_logprobs=None, finished=True, + outputs=[ + CompletionOutput(index=0, text="0.5", token_ids=(0, 0), + cumulative_logprob=None, logprobs=None)], ), + RequestOutput(request_id="1", prompt=None, prompt_token_ids=[0], prompt_logprobs=None, finished=True, + outputs=[ + CompletionOutput(index=0, text=" 1 ", token_ids=(0, 0), + cumulative_logprob=None, logprobs=None)]), + RequestOutput(request_id="2", prompt=None, prompt_token_ids=[0], prompt_logprobs=None, finished=True, + outputs=[ + CompletionOutput(index=0, text="0.8", token_ids=(0, 0), + cumulative_logprob=None, logprobs=None)]) + ] + + data = [{"prompt": "A", "response": "AA", "gt_answer": "AA", "judgement_prompt": "AAA"}, + {"prompt": "B", "response": "BB", "gt_answer": "BB", "judgement_prompt": "BBB"}, + {"prompt": "C", "response": "CC", "gt_answer": "CC", "judgement_prompt": "CCC"}] + + output = self.rj.get_gen_rm_rewards(data) + + assert output[0]["reward"] == 0.5 + assert output[1]["reward"] == 1 + assert output[2]["reward"] == -1 + + def test_rejection_sampling_processor(self): + data = [{"prompt": "A", "response": "AA", "reward": 0.8}, + 
{"prompt": "A", "response": "BB", "reward": 0.9}, + {"prompt": "C", "response": "CC", "reward": 1.0}, + {"prompt": "C", "response": "DD", "reward": -1}, + ] + input_key = "prompt" + response_key = "response" + output = rejection_sampling_processor(data, input_key, response_key) + output.sort(key=lambda x: x["prompt"]) + + assert len(output) == 2 + assert output[0]["prompt"] == "A" + assert output[0]["response"] == "BB" + assert output[0]["reward"] == 0.9 + assert output[1]["prompt"] == "C" + assert output[1]["response"] == "CC" + assert output[1]["reward"] == 1.0 -- Gitee From 33490ce86b17b5855691bd770efa7207a074e594 Mon Sep 17 00:00:00 2001 From: linqihong Date: Thu, 10 Apr 2025 20:46:19 +0800 Subject: [PATCH 2/4] clean code --- cli/rejection_sampling.py | 2 +- mindspeed_rl/datasets/infer_dataset.py | 4 ++-- tests/st/rejection_sampling/test_rj.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cli/rejection_sampling.py b/cli/rejection_sampling.py index 79400b9..0ad82c2 100644 --- a/cli/rejection_sampling.py +++ b/cli/rejection_sampling.py @@ -1,9 +1,9 @@ # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. import gc from typing import Dict, Any +import math import hydra -import math import ray import jsonlines from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel diff --git a/mindspeed_rl/datasets/infer_dataset.py b/mindspeed_rl/datasets/infer_dataset.py index dd1d375..78ea0e1 100644 --- a/mindspeed_rl/datasets/infer_dataset.py +++ b/mindspeed_rl/datasets/infer_dataset.py @@ -1,5 +1,5 @@ # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. -# Copyright (c) 2023-2025 The OpenRLHF Authors. +# Copyright (c) 2025, OpenRLHF. All Rights Reserved. import os from typing import Dict, Any, Callable @@ -64,7 +64,7 @@ def load_datasets( train_data_list = [] - for i, dataset in enumerate(datasets_path): + for _, dataset in enumerate(datasets_path): print(f"dataset: {dataset}") data_dir = dataset.split("@")[1].strip() if "@" in dataset else "" diff --git a/tests/st/rejection_sampling/test_rj.py b/tests/st/rejection_sampling/test_rj.py index 2ed5451..8e0c1c0 100644 --- a/tests/st/rejection_sampling/test_rj.py +++ b/tests/st/rejection_sampling/test_rj.py @@ -2,9 +2,9 @@ import gc import time from typing import Dict, Any +import math import hydra -import math import ray from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel -- Gitee From 261a6f1c449eb35080791cd644c2a5ae2ce222da Mon Sep 17 00:00:00 2001 From: linqihong Date: Fri, 11 Apr 2025 10:32:02 +0800 Subject: [PATCH 3/4] clean code --- tests/st/rejection_sampling/rj_qwen25_7b.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/st/rejection_sampling/rj_qwen25_7b.yaml b/tests/st/rejection_sampling/rj_qwen25_7b.yaml index ef94296..1797812 100644 --- a/tests/st/rejection_sampling/rj_qwen25_7b.yaml +++ b/tests/st/rejection_sampling/rj_qwen25_7b.yaml @@ -1,5 +1,5 @@ rj_config: - model_path: /home/cq/code/MindSpeed-LLM-lxy/MindSpeed-LLM-dl/model_from_hf/Qwen2.5-7B-Instruct + model_path: /data/for_dt/weights/Qwen2.5-7B-Instruct dataset_path: /data/for_dt/datasets/rejection_sampling/generate_output_math.jsonl task: rejection_sampling map_keys: '{"prompt":"prompt", "gt_answer":"gt_answer", "response":"response"}' @@ -20,6 +20,4 @@ rj_config: repetition_penalty: 1.0 top_p: 0.3 top_k: 3 - seed: 1234 - -#/data/for_dt/weights/Qwen2.5-7B-Instruct: \ No newline at end of file + seed: 1234 \ No newline at 
end of file -- Gitee From 648eabfc19cff2cf4a99691215a6099b0c1a105f Mon Sep 17 00:00:00 2001 From: linqihong Date: Fri, 11 Apr 2025 15:51:26 +0800 Subject: [PATCH 4/4] =?UTF-8?q?requirements=E4=BF=AE=E5=A4=8D=E4=B8=BA?= =?UTF-8?q?=E5=8E=9F=E6=8D=A2=E8=A1=8C=E7=AC=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/requirements.txt b/requirements.txt index 556ddae..47c1662 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,16 @@ -codetiming -datasets -hydra-core==1.3.2 -jsonlines -latex2sympy2 -numpy==1.26.4 -omegaconf -pybind11 -ray==2.42.1 -regex -sympy -tensorboard -tensordict -transformers==4.48.2 -word2number -wandb +codetiming +datasets +hydra-core==1.3.2 +jsonlines +latex2sympy2 +numpy==1.26.4 +omegaconf +pybind11 +ray==2.42.1 +regex +sympy +tensorboard +tensordict +transformers==4.48.2 +word2number +wandb -- Gitee
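
Note for anyone picking up this series: the new CLI in cli/rejection_sampling.py is driven entirely by a Hydra config (default config_name 'batch_generate_vllm_qwen25_7b' under configs/), and the patch only ships the ST config for the rejection_sampling stage. As a rough illustration only, with placeholder paths and using only keys defined in RejectionSamplingConfig, a generation-stage config might look like this:

rj_config:
  task: generate_vllm
  model_path: /path/to/Qwen2.5-7B-Instruct      # placeholder, any vLLM-compatible model directory
  dataset_path: /path/to/prompts.jsonl          # placeholder prompt dataset
  output_path: /path/to/generate_output.jsonl   # where {prompt, response, gt_answer} records are written
  map_keys: '{"prompt":"input", "gt_answer":"gt_answer", "response":""}'
  apply_chat_template: true
  iter: -1                    # negative iter: take up to max_samples instead of a rollout slice
  best_of_n: 8                # responses generated per prompt
  num_vllm_instances: 2       # vLLM actors launched through Ray
  sampling_config:
    infer_tensor_parallel_size: 2
    max_new_tokens: 1024
    max_num_seqs: 64
    temperature: 0.5
    top_p: 0.8
    top_k: 5
    repetition_penalty: 1.0
    seed: 1234

Running python cli/rejection_sampling.py --config-name <generate-config> with such a config writes the generation jsonl; a second run with task: rejection_sampling pointed at that file (as in tests/st/rejection_sampling/rj_qwen25_7b.yaml) scores the responses with the generative reward model and keeps the best-scoring response per prompt.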