diff --git a/cli/rejection_sampling.py b/cli/rejection_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ad82c240a4ec27a9c9149b7c03388f4debdbf14
--- /dev/null
+++ b/cli/rejection_sampling.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+import gc
+from typing import Dict, Any
+import math
+
+import hydra
+import ray
+import jsonlines
+from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+
+if not ray.is_initialized():
+ ray.init()
+
+from mindspeed_rl import get_tokenizer
+from mindspeed_rl.models.vllm_infer import RayInferActor
+from mindspeed_rl.utils.loggers import Loggers
+from mindspeed_rl.config_cls.rejection_sampling_config import RejectionSamplingConfig
+from mindspeed_rl.datasets.infer_dataset import load_datasets, PromptGtAnswerDataset
+from mindspeed_rl.models.vllm_gen_rm import RayGenRM
+from mindspeed_rl.utils.rejection_sampler import filtering_by_rules, rejection_sampling_processor
+
+logger = Loggers('rejection_sampling')
+
+
+def clean_up():
+ destroy_model_parallel()
+ destroy_distributed_environment()
+ gc.collect()
+
+
+def parse_config(config: Dict):
+ """
+    Parse the global config and extract the rejection sampling config and the vLLM inference config.
+
+    config: the global configuration dictionary.
+    return: the rejection sampling config and the vLLM inference config.
+ """
+ rj_config = RejectionSamplingConfig(config.get("rj_config"))
+ infer_config = rj_config.sampling_config
+ return rj_config, infer_config
+
+
+def batch_generate_vllm(rj_config: RejectionSamplingConfig, infer_config: Dict[str, Any]) -> None:
+ """
+    Launch multiple vLLM instances with Ray for batch inference and save the prompt, response and gt_answer fields.
+ """
+ tokenizer = get_tokenizer(rj_config.model_path)
+
+ prompts_data = load_datasets(
+ rj_config.dataset_path,
+ rj_config.probabilities,
+ seed=infer_config.seed
+ )
+
+ if rj_config.iter is None or rj_config.iter < 0:
+ prompts_data = prompts_data.select(range(min(rj_config.max_samples, len(prompts_data))))
+ else:
+ start_idx = rj_config.iter * rj_config.rollout_batch_size
+ end_idx = start_idx + rj_config.rollout_batch_size
+ prompts_data = prompts_data.select(range(start_idx, min(end_idx, len(prompts_data))))
+
+ data = PromptGtAnswerDataset(prompts_data, tokenizer, rj_config.map_keys,
+ rj_config.apply_chat_template, input_template=rj_config.input_template)
+ data = list(data)
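+    # Split the data into one chunk per vLLM instance; each chunk's prompts are later
+    # repeated best_of_n times so that every prompt receives best_of_n responses.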
+ batches = []
+ for i in range(0, len(data), math.ceil(len(data) / rj_config.num_vllm_instances)):
+ batches.append(data[i:i + math.ceil(len(data) / rj_config.num_vllm_instances)])
+
+ prompts = [[item["prompt"] for item in batch] for batch in batches]
+ gt_answers = [[item["gt_answer"] for item in batch] for batch in batches]
+
+ vllm_instances = []
+ for _ in range(rj_config.num_vllm_instances):
+ vllm_instances.append(
+ RayInferActor.options(resources={"NPU": infer_config.infer_tensor_parallel_size})
+ .remote(
+ model_path=rj_config.model_path,
+ infer_tensor_parallel_size=infer_config.infer_tensor_parallel_size,
+ seed=infer_config.seed,
+ max_num_seqs=infer_config.max_num_seqs,
+ max_new_tokens=infer_config.max_new_tokens,
+ top_p=infer_config.top_p,
+ top_k=infer_config.top_k,
+ temperature=infer_config.temperature,
+ repetition_penalty=infer_config.repetition_penalty,
+ )
+ )
+
+ for vllm_ins in vllm_instances:
+ vllm_ins.init_vLLM.remote()
+
+ outputs = ray.get([vllm_ins.generate_sequences.remote(prompts[i] * rj_config.best_of_n) for i, vllm_ins in
+ enumerate(vllm_instances)])
+
+ output_dataset = []
+ for i in range(rj_config.num_vllm_instances):
+ for output, gt_answer in zip(outputs[i], gt_answers[i] * rj_config.best_of_n):
+ prompt = output.prompt
+ output = output.outputs[0].text
+ output_dataset.append({"prompt": prompt, "response": output, "gt_answer": gt_answer})
+
+ with jsonlines.open(rj_config.output_path, mode="w") as writer:
+ writer.write_all(output_dataset)
+
+ del vllm_instances
+
+
+def batch_rm_rejection_sampling(rj_config: RejectionSamplingConfig, infer_config: Dict[str, Any]) -> None:
+ """
+    Launch multiple generative reward model instances with Ray to perform rejection sampling.
+ """
+ data = load_datasets(
+ rj_config.dataset_path,
+ probabilities=None,
+ seed=infer_config.seed,
+ max_count=rj_config.max_samples
+ )
+
+ input_key = rj_config.map_keys.get("prompt")
+ gt_answer_key = rj_config.map_keys.get("gt_answer")
+ response_key = rj_config.map_keys.get("response")
+ if rj_config.filter_by_rules:
+ data = filtering_by_rules(data, gt_answer_key, response_key,
+ rj_config.verifier_parallel,
+ rj_config.verifier_function,
+ rj_config.verifier_weight,
+ rj_config.verifier_timeout,
+ rj_config.accept_score)
+
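+    # Split the (optionally rule-filtered) samples into one chunk per reward model instance.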
+ batches = []
+ for i in range(0, len(data), math.ceil(len(data) / rj_config.num_vllm_instances)):
+ batches.append(data[i:i + math.ceil(len(data) / rj_config.num_vllm_instances)])
+
+ rm_instances = []
+ for _ in range(rj_config.num_vllm_instances):
+ rm_instances.append(
+ RayGenRM.options(resources={"NPU": infer_config.infer_tensor_parallel_size})
+ .remote(
+ model_path=rj_config.model_path,
+ infer_tensor_parallel_size=infer_config.infer_tensor_parallel_size,
+ seed=infer_config.seed,
+ max_num_seqs=infer_config.max_num_seqs,
+ max_new_tokens=infer_config.max_new_tokens,
+ top_p=infer_config.top_p,
+ top_k=infer_config.top_k,
+ temperature=infer_config.temperature,
+ repetition_penalty=infer_config.repetition_penalty,
+ map_keys=rj_config.map_keys,
+ use_ground_truth_answer=rj_config.use_ground_truth_answer,
+ )
+ )
+
+ for rm_ins in rm_instances:
+ rm_ins.init_vLLM.remote()
+
+ outputs = ray.get([rm_ins.run_rm_judge.remote(batches[i]) for i, rm_ins in enumerate(rm_instances)])
+
+ outputs = [item for output in outputs for item in output]
+ if rj_config.save_rm_judgement:
+ with jsonlines.open(rj_config.rm_judgement_path, mode="w") as writer:
+ writer.write_all(outputs)
+
+ res = rejection_sampling_processor(outputs, input_key, response_key)
+
+ with jsonlines.open(rj_config.output_path, mode="w") as writer:
+ writer.write_all(res)
+
+ print(f"Processing complete and data saved to '{rj_config.output_path}'.")
+
+ del rm_instances
+
+
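+# Typical invocation (illustrative; the available config names live under ../configs,
+# and any field can be overridden on the command line via Hydra):
+#   python cli/rejection_sampling.py --config-name batch_generate_vllm_qwen25_7b
+#   python cli/rejection_sampling.py --config-name batch_generate_vllm_qwen25_7b rj_config.task=rejection_sampling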
+@hydra.main(config_path='../configs', config_name='batch_generate_vllm_qwen25_7b', version_base=None)
+def main(config):
+ rj_config, infer_config = parse_config(config)
+
+ if rj_config.task == "generate_vllm":
+ batch_generate_vllm(rj_config, infer_config)
+ elif rj_config.task == "rejection_sampling":
+ batch_rm_rejection_sampling(rj_config, infer_config)
+ else:
+        logger.info(
+            "Invalid or missing 'task' config. Please specify either 'generate_vllm' or 'rejection_sampling'.")
+
+ clean_up()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/mindspeed_rl/config_cls/rejection_sampling_config.py b/mindspeed_rl/config_cls/rejection_sampling_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e016433e913f3dde674922b240ed674c514d98f
--- /dev/null
+++ b/mindspeed_rl/config_cls/rejection_sampling_config.py
@@ -0,0 +1,82 @@
+import json
+from mindspeed_rl.config_cls.base_config import BaseConfig
+
+
+class RejectionSamplingConfig(BaseConfig):
+ '''
+ Initialize model configuration from the provided config dictionary.
+ All instance attributes are initialized using the dictionary keys.
+
+    task: The specific task name, either generate_vllm or rejection_sampling (default: None)
+ model_path: Path of the model (default: "")
+ dataset_path: Path of the dataset (default: "")
+    output_path: Path to save the output data (default: "")
+ save_rm_judgement: Enable saving reward model judgements (default: False)
+ rm_judgement_path: Path to save reward model judgements (default: "")
+    map_keys: Dataset keys mapping (default: '{"prompt":"input","gt_answer":"gt_answer","response":""}')
+ probabilities: Mixing probabilities of multiple data sets (default: None)
+ input_template: User-defined template (default: None)
+    apply_chat_template: Whether to use the chat template provided by the model's tokenizer. (default: False)
+
+    iter: Data slice index; selects samples in [iter * rollout_batch_size, (iter + 1) * rollout_batch_size) (default: 0)
+    rollout_batch_size: Number of prompts rolled out per iteration (default: 100)
+    best_of_n: Number of responses generated for each prompt (default: 8)
+    max_samples: Maximum number of data samples; used when iter is not set. (default: 5000000)
+ use_ground_truth_answer: Whether to use the ground truth answer from the dataset. (default: False)
+ num_vllm_instances: The number of running vLLM instances. (default: 1)
+    infer_tensor_parallel_size: Tensor parallel size used for vLLM inference. (default: 1)
+
+ filter_by_rules: Whether to use the rules to filter the dataset. (default: False)
+ accept_score: Threshold for the rule score used in filtering data. (default: 0.5)
+
+    max_new_tokens: Maximum number of newly generated tokens. (default: 1024)
+    max_num_seqs: Maximum number of sequences vLLM processes simultaneously. (default: 64)
+    top_p: The cumulative probability threshold for nucleus sampling. (default: 0.8)
+    top_k: The number of highest-probability tokens to consider for sampling. (default: 5)
+    temperature: Controls the randomness of predictions by scaling the logits before applying softmax. (default: 0.5)
+    repetition_penalty: Penalty applied to repeated content; 1.0 means no penalty. (default: 1.0)
+    seed: Random seed. (default: 1234)
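+
+    Example (minimal sketch; paths are placeholders and BaseConfig.update is assumed to
+    copy the dictionary entries onto the corresponding attributes):
+        cfg = RejectionSamplingConfig({
+            "task": "generate_vllm",
+            "model_path": "/path/to/model",
+            "dataset_path": "/path/to/prompts.jsonl",
+            "output_path": "/path/to/generate_output.jsonl",
+        })
+        cfg.map_keys["prompt"]        # "input", parsed from the JSON default above
+        cfg.sampling_config["top_p"]  # 0.8 unless overridden via sampling_config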
+ '''
+
+ def __init__(self, config_dict):
+ self.task = None
+ self.model_path = ""
+ self.dataset_path = ""
+ self.output_path = ""
+ self.rm_judgement_path = ""
+ self.save_rm_judgement = False
+ self.map_keys = '{"prompt":"input","gt_answer":"gt_answer","response":""}'
+ self.probabilities = None
+ self.input_template = None
+ self.apply_chat_template = False
+
+ self.iter = 0
+ self.rollout_batch_size = 100
+ self.best_of_n = 8
+ self.max_samples = 5000000
+ self.use_ground_truth_answer = False
+
+ self.infer_tensor_parallel_size = 1
+ self.num_vllm_instances = 1
+
+ self.filter_by_rules = False
+ self.verifier_function = ["acc", ]
+ self.verifier_weight = [1.0, ]
+ self.verifier_parallel = 4
+ self.verifier_timeout = 120
+ self.accept_score = 0.5
+
+ self.sampling_config = {
+ "infer_tensor_parallel_size": 1, # tensor并行数
+ "max_new_tokens": 1024, # 生成输出的最大 token 数量
+ "max_num_seqs": 64, # 同时处理的最大序列数
+ "top_p": 0.8, # 核采样的累积概率阈值
+ "top_k": 5, # 采样时考虑的最高概率 token 的数量
+ "temperature": 0.5, # 控制预测随机性的温度参数
+ "repetition_penalty": 1.0, # 重复惩罚系数
+ "seed": 1234 # 随机数种子
+ }
+
+        config_dict['map_keys'] = json.loads(config_dict.get('map_keys', self.map_keys))
+ self.update(config_dict)
diff --git a/mindspeed_rl/datasets/infer_dataset.py b/mindspeed_rl/datasets/infer_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..78ea0e128e223313053f4637c38f96c266c1eb26
--- /dev/null
+++ b/mindspeed_rl/datasets/infer_dataset.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+# Copyright (c) 2025, OpenRLHF. All Rights Reserved.
+
+import os
+from typing import Dict, Any, Callable
+
+from datasets import interleave_datasets, load_dataset, load_from_disk
+from datasets.combine import DatasetType
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+
+def preprocess_data(data: Dict[str, Any], input_template: str = None, input_key: str = "input",
+ apply_chat_template: Callable = None) -> str:
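+    """
+    Build a single prompt string from one data row: either via the tokenizer's chat
+    template (when apply_chat_template is provided) or from the raw field, optionally
+    formatted with input_template.
+    """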
+ if apply_chat_template:
+ chat = data[input_key]
+ if isinstance(chat, str):
+ chat = [{"role": "user", "content": chat}]
+ prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+ else:
+ prompt = data[input_key] if input_key in data.keys() else None
+ if input_template:
+ prompt = input_template.format(prompt)
+ return prompt
+
+
+def load_single_dataset(dataset: str, data_dir: str = None):
+ """
+ Load a single dataset given its path or identifier.
+ Supports loading from different sources: local files, remote datasets, or directories.
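+
+    Examples (illustrative paths, not part of the repo):
+        load_single_dataset("data/train.jsonl")      # load_dataset("json", data_files="data/train.jsonl")
+        load_single_dataset("my_dataset.py")         # load_dataset(..., trust_remote_code=True)
+        load_single_dataset("org/name", "subset")    # load_dataset("org/name", data_dir="subset")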
+ """
+ dataset = dataset.strip()
+ ext = os.path.splitext(dataset)[-1]
+
+ # local python script
+ if ext == ".py" or (
+ os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{os.path.basename(dataset)}.py"))):
+ data = load_dataset(dataset, trust_remote_code=True)
+ # local text file
+ elif ext in [".json", ".jsonl", ".csv", ".parquet"]:
+ ext = ext.lower().strip(".")
+ if ext == "jsonl":
+ ext = "json"
+ data = load_dataset(ext, data_files=dataset)
+ # remote/local folder or common file
+ else:
+ data = load_dataset(dataset, data_dir=data_dir)
+ return data
+
+
+def load_datasets(
+ datasets_path: str,
+ probabilities: str = None,
+ seed: int = 42,
+ max_count: int = -1,
+ train_split: str = "train"
+) -> DatasetType:
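+    """
+    Load and interleave one or more datasets.
+
+    datasets_path: comma-separated dataset paths; a path may carry a data_dir suffix
+        in the form "dataset@data_dir".
+    probabilities: comma-separated mixing probabilities, one per dataset (default: 1.0 for each).
+    max_count: maximum number of samples taken from each dataset (-1 means all).
+    """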
+ datasets_path = datasets_path.split(",")
+ if probabilities is None:
+ probabilities = [1.0] * len(datasets_path)
+ else:
+ probabilities = list(map(float, probabilities.split(",")))
+ assert len(probabilities) == len(datasets_path)
+
+ train_data_list = []
+
+    for dataset in datasets_path:
+        print(f"dataset: {dataset}")
+
+        data_dir = dataset.split("@")[1].strip() if "@" in dataset else ""
+        dataset_name = dataset.split("@")[0].strip()
+
+ # Call load_single_dataset to load the dataset
+ data = load_single_dataset(dataset_name, data_dir)
+
+        # Select train data without clobbering max_count across datasets
+        if train_split and train_split in data:
+            count = min(max_count, len(data[train_split])) if max_count > 0 else len(data[train_split])
+            train_data = data[train_split].select(range(count))
+        else:
+            count = min(max_count, len(data)) if max_count > 0 else len(data)
+            train_data = data.select(range(count))
+ train_data_list.append(train_data)
+
+ # Merge datasets
+ train_dataset = interleave_datasets(
+ train_data_list,
+ probabilities=probabilities,
+ seed=seed,
+ )
+
+ return train_dataset
+
+
+class PromptGtAnswerDataset(Dataset):
+ '''
+    Dataset that yields prompt / gt_answer pairs preprocessed for vLLM inference.
+ '''
+ def __init__(
+ self,
+ dataset: DatasetType,
+ tokenizer: Callable,
+ map_keys: Dict[str, str],
+ apply_chat_template: bool = False,
+ input_template: str = None,
+ ) -> None:
+ super().__init__()
+ self.tokenizer = tokenizer
+
+ self.input_template = input_template
+ input_key = map_keys.get("prompt", None)
+ gt_answer_key = map_keys.get("gt_answer", None)
+
+ if apply_chat_template:
+ apply_chat_template = self.tokenizer.apply_chat_template
+
+ self.prompts = []
+ self.gt_answers = []
+ for data in tqdm(dataset, desc="Preprocessing data"):
+ prompt = preprocess_data(data, input_template, input_key, apply_chat_template)
+ gt_answer = preprocess_data(data, input_key=gt_answer_key)
+ self.prompts.append(prompt)
+ self.gt_answers.append(gt_answer)
+
+ def __len__(self):
+ length = len(self.prompts)
+ return length
+
+ def __getitem__(self, idx):
+ return {"prompt": self.prompts[idx], "gt_answer": self.gt_answers[idx]}
diff --git a/mindspeed_rl/models/vllm_gen_rm.py b/mindspeed_rl/models/vllm_gen_rm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2b1dfb22d5a80b6b82424ad05594ed11aaff1b6
--- /dev/null
+++ b/mindspeed_rl/models/vllm_gen_rm.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+import re
+from typing import List, Dict, Any
+import ray
+
+from mindspeed_rl.models.vllm_infer import InferActor
+
+
+class GenRM(InferActor):
+ """
+    Generative reward model using vLLM for inference.
+
+ Args:
+ model_path: Path of the model.
+ infer_tensor_parallel_size: Tensor parallel size during evaluation. (default: 1)
+ seed: Random seeds. (default: 1234)
+        max_num_seqs: Maximum number of sequences vLLM processes simultaneously. (default: 64)
+        max_new_tokens: Maximum number of newly generated tokens. (default: 2048)
+        top_p: The cumulative probability threshold for nucleus sampling. (default: 0.8)
+        top_k: The number of highest-probability tokens to consider for sampling. (default: 10)
+        temperature: Controls the randomness of predictions by scaling the logits before applying softmax. (default: 1.0)
+        repetition_penalty: Penalty applied to repeated content; 1.0 means no penalty. (default: 1.0)
+        map_keys: Dataset keys mapping. (default: None)
+        use_ground_truth_answer: Whether to use the ground truth answer from the dataset. (default: False)
+ """
+ def __init__(self,
+ model_path: str,
+ infer_tensor_parallel_size: int = 1,
+ seed: int = 1234,
+ max_num_seqs: int = 64,
+ max_new_tokens: int = 2048,
+ top_p: float = 0.8,
+ top_k: int = 10,
+ temperature: float = 1.0,
+ repetition_penalty: float = 1.0,
+ map_keys: Dict[str, str] = None,
+ use_ground_truth_answer: bool = False,
+ ):
+ super(GenRM, self).__init__(model_path,
+ infer_tensor_parallel_size,
+ seed,
+ max_num_seqs,
+ max_new_tokens,
+ top_p,
+ top_k,
+ temperature,
+ repetition_penalty)
+ self.input_key = map_keys.get("prompt")
+ self.response_key = map_keys.get("response")
+ self.gt_answer_key = map_keys.get("gt_answer")
+ self.use_ground_truth_answer = use_ground_truth_answer
+
+ def process_row_function(self, example):
+ """
+        Process one data row: extract the required fields and build the judgement prompt from the generative reward model template.
+ """
+ prompt_text = example[self.input_key]
+ response_text = example[self.response_key]
+ if self.use_ground_truth_answer:
+ gt_answer_text = example[self.gt_answer_key]
+ else:
+ gt_answer_text = None
+
+ judgement_prompt = self.apply_gen_rm_template(prompt_text, response_text, gt_answer_text)
+ example['judgement_prompt'] = judgement_prompt
+ return example
+
+ def get_gen_rm_rewards(self, input_data: List[Dict[str, Any]]):
+ """
+        Judge each response with the generative reward model and parse the judgement text into a reward score.
+ """
+ judgement_prompts = [item['judgement_prompt'] for item in input_data]
+ judgements = self.generate_sequences(judgement_prompts)
+
+ output_dataset = []
+ for example, judgement in zip(input_data, judgements):
+ example["judgement"] = judgement.outputs[0].text
+ judgement_parsing = re.findall(r'(\s*-?\d+(?:\.\d+)?\s*)', example["judgement"])
+ if judgement_parsing:
+ example["reward"] = float(judgement_parsing[0])
+ else:
+ example["reward"] = -1
+ output_dataset.append(example)
+ return output_dataset
+
+ def apply_gen_rm_template(self, prompt_text: str, response_text: str, ground_truth_answer_text: str) -> str:
+ """
+        Build the judgement prompt for the generative reward model, choosing between the templates with and without a ground truth answer.
+ """
+ if ground_truth_answer_text:
+ full_input = f"""<|im_start|>你是一个判别推理正确性的专家。
+ 问题[PROMPT]: {prompt_text}
+ 正确答案[GROUND TRUTH]: {ground_truth_answer_text}
+ 回复[RESPONSE]: {response_text}
+ 任务目标:根据给定的问题[PROMPT]和正确答案[GROUND TRUTH],评估回复[RESPONSE]质量。重点考虑回复结果的正确性,\
+ 其次考虑语言一致性、格式正确性、推理合理性、语句重复冗余性,用简洁的文字说明原因。\
+ 最后给出0到1之间的分数,分数以score here的形式给出。<|im_end|>
+ <|im_start|>assistant
+ """
+ else:
+ full_input = f"""<|im_start|>你是一个判别推理正确性的专家。
+ 问题[PROMPT]: {prompt_text}
+ 回复[RESPONSE]: {response_text}
+ 任务目标:根据给定的问题[PROMPT],评估回复[RESPONSE]质量,考虑回复结果的正确性、语言一致性、\
+ 格式正确性、推理合理性、回复无害性、语句重复冗余性,用简洁的文字说明原因。\
+ 最后给出0到1之间的分数,分数以score here的形式给出。<|im_end|>
+ <|im_start|>assistant
+ """
+ return full_input
+
+ def run_rm_judge(self, input_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+        Run the generative reward model judgement over the input data.
+ """
+ data = []
+ for item in input_data:
+ data.append(self.process_row_function(item))
+ res = self.get_gen_rm_rewards(data)
+ return res
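+
+    # Minimal local (non-Ray) usage sketch; the model path and data are placeholders:
+    #     rm = GenRM(model_path="/path/to/reward_model",
+    #                map_keys={"prompt": "prompt", "gt_answer": "gt_answer", "response": "response"})
+    #     rm.init_vLLM()
+    #     judged = rm.run_rm_judge([{"prompt": "1+1=?", "response": "2", "gt_answer": "2"}])
+    #     judged[0]["reward"]  # parsed score, or -1 if the judgement could not be parsed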
+
+
+@ray.remote
+class RayGenRM(GenRM):
+ def __init__(self,
+ model_path: str,
+ infer_tensor_parallel_size: int = 1,
+ seed: int = 1234,
+ max_num_seqs: int = 64,
+ max_new_tokens: int = 2048,
+ top_p: float = 0.8,
+ top_k: int = 10,
+ temperature: float = 1.0,
+ repetition_penalty: float = 1.0,
+ map_keys: Dict[str, str] = None,
+ use_ground_truth_answer: bool = False):
+ super().__init__(model_path,
+ infer_tensor_parallel_size,
+ seed,
+ max_num_seqs,
+ max_new_tokens,
+ top_p,
+ top_k,
+ temperature,
+ repetition_penalty,
+ map_keys,
+ use_ground_truth_answer)
diff --git a/mindspeed_rl/models/vllm_infer.py b/mindspeed_rl/models/vllm_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d371aec9f9841e586263ee8cac2e4e07b3034d
--- /dev/null
+++ b/mindspeed_rl/models/vllm_infer.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+from typing import List
+
+import ray
+from vllm import LLM, SamplingParams
+
+
+class InferActor():
+ """
+    Inference actor using vLLM for offline generation.
+
+ Args:
+ model_path: Path of the model.
+ infer_tensor_parallel_size: Tensor parallel size during evaluation. (default: 1)
+ seed: Random seeds. (default: 1234)
+    max_num_seqs: Maximum number of sequences vLLM processes simultaneously. (default: 64)
+    max_new_tokens: Maximum number of newly generated tokens. (default: 2048)
+    top_p: The cumulative probability threshold for nucleus sampling. (default: 0.8)
+    top_k: The number of highest-probability tokens to consider for sampling. (default: 10)
+    temperature: Controls the randomness of predictions by scaling the logits before applying softmax. (default: 1.0)
+    repetition_penalty: Penalty applied to repeated content; 1.0 means no penalty. (default: 1.0)
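+
+    Example (minimal sketch; the model path is a placeholder):
+        actor = InferActor(model_path="/path/to/model")
+        actor.init_vLLM()
+        outputs = actor.generate_sequences(["Hello, who are you?"])
+        print(outputs[0].outputs[0].text)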
+ """
+ def __init__(self,
+ model_path: str,
+ infer_tensor_parallel_size: int = 1,
+ seed: int = 1234,
+ max_num_seqs: int = 64,
+ max_new_tokens: int = 2048,
+ top_p: float = 0.8,
+ top_k: int = 10,
+ temperature: float = 1.0,
+ repetition_penalty: float = 1.0,
+ ):
+ self.llm = None
+ self.sampling_params = None
+ self.model_path = model_path
+ self.infer_tensor_parallel_size = infer_tensor_parallel_size
+ self.seed = seed
+ self.max_num_seqs = max_num_seqs
+ self.max_new_tokens = max_new_tokens
+ self.top_p = top_p
+ self.top_k = top_k
+ self.temperature = temperature
+ self.repetition_penalty = repetition_penalty
+
+ def init_vLLM(self):
+ """
+        Initialize the offline inference engine vllm.LLM() and its SamplingParams.
+ """
+ self.llm = LLM(
+ model=self.model_path,
+ tensor_parallel_size=self.infer_tensor_parallel_size,
+ trust_remote_code=True,
+ seed=self.seed,
+ max_num_seqs=self.max_num_seqs
+ )
+
+ # Create a sampling params object.
+ self.sampling_params = SamplingParams(
+ max_tokens=self.max_new_tokens,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ temperature=self.temperature,
+ repetition_penalty=self.repetition_penalty
+ )
+
+ def generate_sequences(self, prompts: List[str]):
+ """generate sequences via vLLM offline inference"""
+ outputs = self.llm.generate(prompts, self.sampling_params)
+
+ return outputs
+
+
+@ray.remote
+class RayInferActor(InferActor):
+ def __init__(self,
+ model_path: str,
+ infer_tensor_parallel_size: int = 1,
+ seed: int = 1234,
+ max_num_seqs: int = 64,
+ max_new_tokens: int = 2048,
+ top_p: float = 0.8,
+ top_k: int = 10,
+ temperature: float = 1.0,
+                 repetition_penalty: float = 1.0):
+ super().__init__(model_path,
+ infer_tensor_parallel_size,
+ seed,
+ max_num_seqs,
+ max_new_tokens,
+ top_p,
+ top_k,
+ temperature,
+ repetition_penalty)
diff --git a/mindspeed_rl/utils/rejection_sampler.py b/mindspeed_rl/utils/rejection_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..998057a54f96be29b98d07f321c6b166ddbc2405
--- /dev/null
+++ b/mindspeed_rl/utils/rejection_sampler.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+from typing import List, Dict, Any
+from tqdm import tqdm
+
+
+def filtering_by_rules(input_data: List[Dict[str, Any]],
+ gt_answer_key: str,
+ response_key: str,
+ verifier_parallel: int = 4,
+ verifier_function: List[str] = ["acc"],
+ verifier_weight: List[float] = [1.0],
+ verifier_timeout: int = 300,
+ accept_score: float = 0.5
+ ):
+ """
+    Rule-based filtering: score each response with the configured verifiers and keep only samples whose rule score is at least accept_score.
+ """
+ from types import SimpleNamespace
+ from mindspeed_rl.models.rule_verifier import verifier
+ labels = {}
+ labels["labels"] = [data[gt_answer_key] for data in input_data]
+ responses = [data[response_key] for data in input_data]
+
+ config = SimpleNamespace()
+ config.verifier_parallel = verifier_parallel
+ config.verifier_function = verifier_function
+ config.verifier_weight = verifier_weight
+ config.verifier_timeout = verifier_timeout
+
+    scores, _ = verifier(responses, labels, config)
+
+ output_dataset = []
+ for example, score in zip(input_data, scores):
+ example["rule_score"] = score
+ if score >= accept_score:
+ example['select'] = True
+ else:
+ example['select'] = False
+ output_dataset.append(example)
+ output_dataset = list(filter(lambda item: item["select"], output_dataset))
+ return output_dataset
+
+
+def rejection_sampling_processor(data: List[Dict[str, Any]], input_key: str, response_key: str):
+ """
+    Rejection sampling: for each prompt, keep only the highest-reward response, dropping prompts whose rewards are all non-positive.
+ """
+ out = {}
+ for item in tqdm(data, desc="Rejection Sampling process...."):
+ prompt = item[input_key]
+ response = item[response_key]
+ reward = item["reward"]
+
+ if reward > 0:
+ if prompt not in out:
+ out[prompt] = {"response": response, "reward": reward}
+ elif reward > out[prompt]["reward"]:
+ out[prompt]["reward"] = reward
+ out[prompt]["response"] = response
+
+ return [{"prompt": k, "response": v["response"], "reward": v["reward"]} for k, v in out.items()]
diff --git a/requirements.txt b/requirements.txt
index 76f105613c476d80a16ceb0fafd1f66c10ffef2f..47c1662429792d4556c98c1c39ed3e27dfb8c94f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
codetiming
datasets
hydra-core==1.3.2
+jsonlines
latex2sympy2
numpy==1.26.4
omegaconf
diff --git a/tests/st/rejection_sampling/rj_qwen25_7b.yaml b/tests/st/rejection_sampling/rj_qwen25_7b.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..17978120eefe31bd2bea2192577772bd502ee536
--- /dev/null
+++ b/tests/st/rejection_sampling/rj_qwen25_7b.yaml
@@ -0,0 +1,23 @@
+rj_config:
+ model_path: /data/for_dt/weights/Qwen2.5-7B-Instruct
+ dataset_path: /data/for_dt/datasets/rejection_sampling/generate_output_math.jsonl
+ task: rejection_sampling
+ map_keys: '{"prompt":"prompt", "gt_answer":"gt_answer", "response":"response"}'
+ use_ground_truth_answer: True
+ num_vllm_instances: 2
+ filter_by_rules: true
+ verifier_function: ["acc", ]
+ verifier_weight: [1.0, ]
+ verifier_parallel: 4
+ verifier_timeout: 120
+ accept_score: -1.0
+
+ sampling_config:
+ infer_tensor_parallel_size: 2
+ max_new_tokens: 512
+ max_num_seqs: 64
+ temperature: 0.3
+ repetition_penalty: 1.0
+ top_p: 0.3
+ top_k: 3
+ seed: 1234
\ No newline at end of file
diff --git a/tests/st/rejection_sampling/test_module_entry_rejection_sampling.sh b/tests/st/rejection_sampling/test_module_entry_rejection_sampling.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8f1e00e3096e4db2f2bb0af5d89951d5d75f1117
--- /dev/null
+++ b/tests/st/rejection_sampling/test_module_entry_rejection_sampling.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+# Get the absolute path of this script
+SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
+export PYTHONPATH=$SCRIPT_DIR/../../..:$PYTHONPATH
+
+ray stop
+
+python $SCRIPT_DIR/test_rj.py --config-name rj_qwen25_7b.yaml
\ No newline at end of file
diff --git a/tests/st/rejection_sampling/test_rj.py b/tests/st/rejection_sampling/test_rj.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e0c1c0fbe3d7103dc39478a69714192b2563d60
--- /dev/null
+++ b/tests/st/rejection_sampling/test_rj.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+import gc
+import time
+from typing import Dict, Any, List
+import math
+
+import hydra
+import ray
+from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+
+if not ray.is_initialized():
+ ray.init()
+
+from mindspeed_rl.datasets.infer_dataset import load_datasets
+from mindspeed_rl.config_cls.rejection_sampling_config import RejectionSamplingConfig
+from mindspeed_rl.models.vllm_gen_rm import RayGenRM
+from mindspeed_rl.utils.loggers import Loggers
+from mindspeed_rl.utils.rejection_sampler import rejection_sampling_processor, filtering_by_rules
+
+logger = Loggers('rejection_sampling')
+
+
+def clean_up():
+ destroy_model_parallel()
+ destroy_distributed_environment()
+ gc.collect()
+
+
+def parse_config(config: Dict):
+ """
+    Parse the global config and extract the rejection sampling config and the vLLM inference config.
+
+    config: the global configuration dictionary.
+    return: the rejection sampling config and the vLLM inference config.
+ """
+ rj_config = RejectionSamplingConfig(config.get("rj_config"))
+ infer_config = rj_config.sampling_config
+ return rj_config, infer_config
+
+
+def batch_rm_rejection_sampling(rj_config: RejectionSamplingConfig, infer_config: Dict[str, Any]) -> List[Dict[str, Any]]:
+ """
+    Launch multiple generative reward model instances with Ray to perform rejection sampling.
+ """
+ data = load_datasets(
+ rj_config.dataset_path,
+ probabilities=None,
+ seed=infer_config.seed,
+ max_count=rj_config.max_samples
+ )
+
+ input_key = rj_config.map_keys.get("prompt")
+ gt_answer_key = rj_config.map_keys.get("gt_answer")
+ response_key = rj_config.map_keys.get("response")
+ if rj_config.filter_by_rules:
+ data = filtering_by_rules(data, gt_answer_key, response_key,
+ rj_config.verifier_parallel,
+ rj_config.verifier_function,
+ rj_config.verifier_weight,
+ rj_config.verifier_timeout,
+ rj_config.accept_score)
+
+ batches = []
+ for i in range(0, len(data), math.ceil(len(data) / rj_config.num_vllm_instances)):
+ batches.append(data[i:i + math.ceil(len(data) / rj_config.num_vllm_instances)])
+
+ rm_instances = []
+ for _ in range(rj_config.num_vllm_instances):
+ rm_instances.append(
+ RayGenRM.options(resources={"NPU": infer_config.infer_tensor_parallel_size})
+ .remote(
+ model_path=rj_config.model_path,
+ infer_tensor_parallel_size=infer_config.infer_tensor_parallel_size,
+ seed=infer_config.seed,
+ max_num_seqs=infer_config.max_num_seqs,
+ max_new_tokens=infer_config.max_new_tokens,
+ top_p=infer_config.top_p,
+ top_k=infer_config.top_k,
+ temperature=infer_config.temperature,
+ repetition_penalty=infer_config.repetition_penalty,
+ map_keys=rj_config.map_keys,
+ use_ground_truth_answer=rj_config.use_ground_truth_answer,
+ )
+ )
+
+ for rm_ins in rm_instances:
+ rm_ins.init_vLLM.remote()
+
+ outputs = ray.get([rm_ins.run_rm_judge.remote(batches[i]) for i, rm_ins in enumerate(rm_instances)])
+ outputs = [item for output in outputs for item in output]
+
+ res = rejection_sampling_processor(outputs, input_key, response_key)
+
+ del rm_instances
+
+ return res
+
+
+@hydra.main(config_path='./', config_name='rj_qwen25_7b', version_base=None)
+def main(config):
+ rj_config, infer_config = parse_config(config)
+    start_time = time.time()
+
+ if rj_config.task == "rejection_sampling":
+ data_path = None
+ if 'rejection_sampling' in rj_config.dataset_path:
+ data_path = rj_config.dataset_path
+
+ try:
+ if data_path:
+ result = batch_rm_rejection_sampling(rj_config, infer_config)
+ logger.info('\n rejection sampling result length: {}'.format(len(result)))
+ logger.info('\n rejection sampling result: {}'.format(result))
+ except Exception as e:
+ logger.info(e)
+
+    logger.info(f'Rejection Sampling Running Time: {time.time() - start_time}')
+
+ clean_up()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/ut/datasets/test_prompt_gt_answer_dataset.py b/tests/ut/datasets/test_prompt_gt_answer_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f916c3637e0235851fb819f5db24f2484154526a
--- /dev/null
+++ b/tests/ut/datasets/test_prompt_gt_answer_dataset.py
@@ -0,0 +1,54 @@
+# coding=utf-8
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+
+from datasets import Dataset
+
+from mindspeed_rl.datasets.infer_dataset import PromptGtAnswerDataset
+from mindspeed_rl.utils.tokenizer import get_tokenizer
+
+from tests.test_tools.dist_test import DistributedTest
+
+
+class TestPromptGtAnswerDataset(DistributedTest):
+ world_size = 1
+
+ def test_prompt_gt_answer_dataset(self):
+ tokenizer_directory = '/data/models/llama2-7b'
+ hf_tokenizer = get_tokenizer(tokenizer_directory)
+ map_keys = {"prompt": "input", "gt_answer": "gt_answer"}
+ input_data = Dataset.from_dict({"input": ["AAA", "BBB"], "gt_answer": ["A", "B"]})
+
+ dataset = PromptGtAnswerDataset(
+ dataset=input_data,
+ tokenizer=hf_tokenizer,
+ map_keys=map_keys,
+ apply_chat_template=False,
+ input_template=None,
+ )
+ assert len(dataset) == 2, "The __len__ method of the PromptGtAnswerDataset failed!"
+
+ assert dataset[0]["prompt"] == "AAA", "The __get_item__ method of the PromptGtAnswerDataset failed!"
+ assert dataset[0]["gt_answer"] == "A", "The __get_item__ method of the PromptGtAnswerDataset failed!"
+ assert dataset[1]["prompt"] == "BBB", "The __get_item__ method of the PromptGtAnswerDataset failed!"
+ assert dataset[1]["gt_answer"] == "B", "The __get_item__ method of the PromptGtAnswerDataset failed!"
+
+ def test_input_template(self):
+ tokenizer_directory = '/data/models/llama2-7b'
+ hf_tokenizer = get_tokenizer(tokenizer_directory)
+ map_keys = {"prompt": "input", "gt_answer": "gt_answer"}
+ input_data = Dataset.from_dict({"input": ["AAA", "BBB"], "gt_answer": ["A", "B"]})
+ input_template = "<|im_start|>You are a helpful assitant.<|im_end|>User:{} Assitant:"
+ dataset = PromptGtAnswerDataset(
+ dataset=input_data,
+ tokenizer=hf_tokenizer,
+ map_keys=map_keys,
+ apply_chat_template=False,
+ input_template=input_template,
+ )
+ assert len(dataset) == 2, "The __len__ method of the PromptGtAnswerDataset failed!"
+ assert dataset[0]["prompt"] == "<|im_start|>You are a helpful assitant.<|im_end|>User:AAA Assitant:", \
+ "The __get_item__ method of the PromptGtAnswerDataset failed!"
+ assert dataset[0]["gt_answer"] == "A", "The __get_item__ method of the PromptGtAnswerDataset failed!"
+ assert dataset[1]["prompt"] == "<|im_start|>You are a helpful assitant.<|im_end|>User:BBB Assitant:", \
+ "The __get_item__ method of the PromptGtAnswerDataset failed!"
+ assert dataset[1]["gt_answer"] == "B", "The __get_item__ method of the PromptGtAnswerDataset failed!"
diff --git a/tests/ut/utils/test_rejection_sampling.py b/tests/ut/utils/test_rejection_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5480e12292ca42c781bd4d5dac27f23cc871c4e
--- /dev/null
+++ b/tests/ut/utils/test_rejection_sampling.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+
+from unittest.mock import patch
+import pytest
+from vllm import RequestOutput, CompletionOutput
+
+from mindspeed_rl.models.vllm_gen_rm import GenRM
+from mindspeed_rl.models.vllm_infer import InferActor
+from mindspeed_rl.utils.rejection_sampler import rejection_sampling_processor
+from tests.test_tools.dist_test import DistributedTest
+
+
+class TestRejectionSampling(DistributedTest):
+
+ @pytest.fixture
+ def setUp(self):
+ map_keys = {"prompt": "prompt", "gt_answer": "gt_answer", "response": "response"}
+ self.rj = GenRM(model_path="", map_keys=map_keys)
+
+ def test_initialization(self, setUp):
+ assert self.rj.input_key == "prompt"
+ assert self.rj.response_key == "response"
+ assert self.rj.gt_answer_key == "gt_answer"
+
+ @patch.object(InferActor, "generate_sequences")
+ def test_get_gen_rm_rewards(self, mock_generate_sequences, setUp):
+ mock_generate_sequences.return_value = [
+ RequestOutput(request_id="0", prompt=None, prompt_token_ids=[0], prompt_logprobs=None, finished=True,
+ outputs=[
+ CompletionOutput(index=0, text="0.5", token_ids=(0, 0),
+ cumulative_logprob=None, logprobs=None)], ),
+ RequestOutput(request_id="1", prompt=None, prompt_token_ids=[0], prompt_logprobs=None, finished=True,
+ outputs=[
+ CompletionOutput(index=0, text=" 1 ", token_ids=(0, 0),
+ cumulative_logprob=None, logprobs=None)]),
+ RequestOutput(request_id="2", prompt=None, prompt_token_ids=[0], prompt_logprobs=None, finished=True,
+ outputs=[
+ CompletionOutput(index=0, text="0.8", token_ids=(0, 0),
+ cumulative_logprob=None, logprobs=None)])
+ ]
+
+ data = [{"prompt": "A", "response": "AA", "gt_answer": "AA", "judgement_prompt": "AAA"},
+ {"prompt": "B", "response": "BB", "gt_answer": "BB", "judgement_prompt": "BBB"},
+ {"prompt": "C", "response": "CC", "gt_answer": "CC", "judgement_prompt": "CCC"}]
+
+ output = self.rj.get_gen_rm_rewards(data)
+
+ assert output[0]["reward"] == 0.5
+ assert output[1]["reward"] == 1
+ assert output[2]["reward"] == -1
+
+ def test_rejection_sampling_processor(self):
+ data = [{"prompt": "A", "response": "AA", "reward": 0.8},
+ {"prompt": "A", "response": "BB", "reward": 0.9},
+ {"prompt": "C", "response": "CC", "reward": 1.0},
+ {"prompt": "C", "response": "DD", "reward": -1},
+ ]
+ input_key = "prompt"
+ response_key = "response"
+ output = rejection_sampling_processor(data, input_key, response_key)
+ output.sort(key=lambda x: x["prompt"])
+
+ assert len(output) == 2
+ assert output[0]["prompt"] == "A"
+ assert output[0]["response"] == "BB"
+ assert output[0]["reward"] == 0.9
+ assert output[1]["prompt"] == "C"
+ assert output[1]["response"] == "CC"
+ assert output[1]["reward"] == 1.0