From a69aab8dcc3dee6311883e4711665ebbfad1254d Mon Sep 17 00:00:00 2001 From: zhuyilin Date: Sun, 23 Mar 2025 11:06:28 +0800 Subject: [PATCH] add evaluate --- evaluation.py | 131 +++++++++--- .../evaluate_qwen25_math_7b_ptd_all.sh | 74 +++++++ .../evaluate_qwen25_math_7b_ptd_single.sh | 76 +++++++ .../tasks/evaluation/eval_impl/aime_eval.py | 167 +++++++++++++++ .../tasks/evaluation/eval_impl/gpqa_eval.py | 162 +++++++++++++++ .../evaluation/eval_impl/math500_eval.py | 190 ++++++++++++++++++ 6 files changed, 767 insertions(+), 33 deletions(-) create mode 100644 examples/mcore/qwen25_math/evaluate_qwen25_math_7b_ptd_all.sh create mode 100644 examples/mcore/qwen25_math/evaluate_qwen25_math_7b_ptd_single.sh create mode 100644 mindspeed_llm/tasks/evaluation/eval_impl/aime_eval.py create mode 100644 mindspeed_llm/tasks/evaluation/eval_impl/gpqa_eval.py create mode 100644 mindspeed_llm/tasks/evaluation/eval_impl/math500_eval.py diff --git a/evaluation.py b/evaluation.py index dbb197df0..bc49ce5eb 100644 --- a/evaluation.py +++ b/evaluation.py @@ -43,7 +43,9 @@ from mindspeed_llm.tasks.evaluation.eval_impl.ceval_exam import CEvalExam from mindspeed_llm.tasks.evaluation.eval_impl.bbh_eval import BBHEval from mindspeed_llm.tasks.evaluation.eval_impl.agi_eval import AGIEvalExam from mindspeed_llm.tasks.evaluation.eval_impl.human_eval import HumanEval -from mindspeed_llm.tasks.evaluation.eval_impl.cmmlu_eval import CmmluEval +from mindspeed_llm.tasks.evaluation.eval_impl.gpqa_eval import GPQAEval +from mindspeed_llm.tasks.evaluation.eval_impl.aime_eval import AIMEEval +from mindspeed_llm.tasks.evaluation.eval_impl.math500_eval import MATH500Eval from mindspeed_llm.tasks.evaluation.eval_impl.needlebench_eval import NeedleBenchEval sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) @@ -156,8 +158,7 @@ class LLMChat(Chat): do_sample=False, max_new_tokens=self.args.max_new_tokens, stream=False, - return_output_log_probs=return_output_log_probs, - broadcast=self.args.broadcast + return_output_log_probs=return_output_log_probs ) if getattr(self.args, "task", False) and self.args.task[0] == 'needlebench': return result, dist.get_rank() @@ -188,7 +189,8 @@ def mmlu(eval_args, agent): answer = None score_df = None for path in eval_args.task_data_path: - data_path = path + if 'mmlu' in path: + data_path = path try: if data_path: mmlu_eval = MmluEval(test_dir=data_path, eval_args=eval_args) @@ -201,27 +203,11 @@ def mmlu(eval_args, agent): return answer, score_df -def cmmlu(eval_args, agent): - data_path = None - answer = None - score_df = None - for path in eval_args.task_data_path: - data_path = path - try: - if data_path: - cmmlu_eval = CmmluEval(test_dir=data_path, eval_args=eval_args) - answer, score_df = cmmlu_eval.eval(chat=agent) - if dist.get_rank() == 0: - logger.info('\n{}'.format(score_df)) - except Exception as e: - logger.info(e) - return answer, score_df - - def needlebench(eval_args, agent): data_path = None for path in eval_args.task_data_path: - data_path = path + if 'needlebench' in path: + data_path = path try: if data_path: needlebench_eval = NeedleBenchEval(test_dir=data_path, eval_args=eval_args) @@ -237,7 +223,8 @@ def gsm8k(eval_args, agent): answer = None score_df = None for path in eval_args.task_data_path: - data_path = path + if 'gsm8k' in path: + data_path = path try: if data_path: gsm8k_eval = Gsm8kEval(test_dir=data_path, eval_args=eval_args) @@ -256,7 +243,8 @@ def boolq(eval_args, agent): score_df = None for path in 
eval_args.task_data_path: - data_path = path + if 'boolq' in path: + data_path = path try: if data_path: boolq_eval = BoolqEval(test_dir=data_path, eval_args=eval_args) @@ -275,7 +263,8 @@ def ceval(eval_args, agent): score_df = None for path in eval_args.task_data_path: - data_path = path + if 'ceval' in path: + data_path = path try: if data_path: ceval_exam = CEvalExam(test_dir=data_path, eval_args=eval_args) @@ -294,7 +283,8 @@ def human_eval(eval_args, agent): score_df = None for path in eval_args.task_data_path: - data_path = path + if 'human_eval' in path: + data_path = path try: if data_path: human_eval_exam = HumanEval(test_dir=data_path, eval_args=eval_args) @@ -313,7 +303,8 @@ def agi_eval(eval_args, agent): score_df = None for path in eval_args.task_data_path: - data_path = path + if 'agieval' in path: + data_path = path try: if data_path: agieval_exam = AGIEvalExam(test_dir=data_path, eval_args=eval_args) @@ -326,13 +317,72 @@ def agi_eval(eval_args, agent): return answer, score_df +def gpqa_eval(eval_args, agent): + data_path = None + answer = None + score_df = None + for path in eval_args.task_data_path: + if 'gpqa' in path: + data_path = path + try: + if data_path: + gpqa_eval = GPQAEval(test_dir=data_path, eval_args=eval_args) + answer, score_df = gpqa_eval.eval(chat=agent) + if dist.get_rank() == 0: + logger.info('\n{}'.format(score_df)) + except Exception as e: + logger.info(e) + + return answer, score_df + + +def aime_eval(eval_args, agent): + data_path = None + answer = None + score_df = None + + for path in eval_args.task_data_path: + if 'aime' in path: + data_path = path + try: + if data_path: + aime2024_eval = AIMEEval(test_dir=data_path, eval_args=eval_args) + answer, score_df = aime2024_eval.eval(chat=agent) + if dist.get_rank() == 0: + logger.info('\n{}'.format(score_df)) + except Exception as e: + logger.info(e) + + return answer, score_df + +def math500_eval(eval_args, agent): + data_path = None + answer = None + score_df = None + + for path in eval_args.task_data_path: + if 'math500' in path: + data_path = path + try: + if data_path: + math_eval = MATH500Eval(test_dir=data_path, eval_args=eval_args) + answer, score_df = math_eval.eval(chat=agent) + if dist.get_rank() == 0: + logger.info('\n{}'.format(score_df)) + except Exception as e: + logger.info(e) + + return answer, score_df + + def bbh_eval(eval_args, agent): data_path = None answer = None score_df = None for path in eval_args.task_data_path: - data_path = path + if 'bbh' in path: + data_path = path try: if data_path: bbh = BBHEval(test_dir=data_path, eval_args=eval_args) @@ -357,11 +407,8 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path, trust_remote_code=True, local_files_only=True) rank = dist.get_rank() - if 'cmmlu' in args.task: - a = time.time() - cmmlu(args, LLMChat(args, model, tokenizer)) - if rank == 0: - logger.info(f'CMMLU Running Time:, {time.time() - a}') + print('load ckpt finished', flush=True) + if 'mmlu' in args.task: a = time.time() mmlu(args, LLMChat(args, model, tokenizer)) @@ -404,6 +451,24 @@ def main(): logger.info(f'NeedleBench_eval Running Time: {time.time() - a}') + if 'gpqa_eval' in args.task: + a = time.time() + gpqa_eval(args, LLMChat(args, model, tokenizer)) + if rank == 0: + logger.info(f'gpqa_eval Running Time: {time.time() - a}') + + if 'aime_eval' in args.task: + a = time.time() + aime_eval(args, LLMChat(args, model, tokenizer)) + if rank == 0: + logger.info(f'aime_eval Running Time: {time.time() - a}') + + if 'math500_eval' in 
args.task: + a = time.time() + math500_eval(args, LLMChat(args, model, tokenizer)) + if rank == 0: + logger.info(f'math500_eval Running Time: {time.time() - a}') + if __name__ == "__main__": main() diff --git a/examples/mcore/qwen25_math/evaluate_qwen25_math_7b_ptd_all.sh b/examples/mcore/qwen25_math/evaluate_qwen25_math_7b_ptd_all.sh new file mode 100644 index 000000000..6dca073d6 --- /dev/null +++ b/examples/mcore/qwen25_math/evaluate_qwen25_math_7b_ptd_all.sh @@ -0,0 +1,74 @@ +#!/bin/bash +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=7000 +NNODES=1 +NODE_RANK=0 +NPUS_PER_NODE=1 +WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES)) + +# please fill these path configurations +CHECKPOINT=/data/00805503/model_from_hf/qwen2.5_math_mcore/ +TOKENIZER_PATH=/data/00805503/model_from_hf/qwen2.5_math_hf/ +DATA_PATH="/efs_gy1/yangziliang/data/tasks/math500 /efs_gy1/yangziliang/data/tasks/gpqa /efs_gy1/yangziliang/data/tasks/aime" +TASK="math500_eval gpqa_eval aime_eval" + + +TP=1 +PP=1 +SEQ_LENGTH=4096 + +DISTRIBUTED_ARGS=" + --nproc_per_node $NPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +# Different task needs different max_new_tokens value, please follow the instruction in readme. +torchrun $DISTRIBUTED_ARGS evaluation.py \ + --evaluation-batch-size 2 \ + --use-mcore-models \ + --task ${TASK} \ + --task-data-path $DATA_PATH \ + --no-chat-template \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --num-layers 28 \ + --max-length 4096 \ + --max-new-tokens 2048 \ + --hidden-size 3584 \ + --ffn-hidden-size 18944 \ + --num-attention-heads 28 \ + --max-position-embeddings ${SEQ_LENGTH} \ + --seq-length ${SEQ_LENGTH} \ + --disable-bias-linear \ + --add-qkv-bias \ + --group-query-attention \ + --num-query-groups 4 \ + --swiglu \ + --use-fused-swiglu \ + --normalization RMSNorm \ + --norm-epsilon 1e-6 \ + --use-fused-rmsnorm \ + --position-embedding-type rope \ + --rotary-base 10000 \ + --use-fused-rotary-pos-emb \ + --make-vocab-size-divisible-by 1 \ + --padded-vocab-size 152064 \ + --micro-batch-size 1 \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_PATH} \ + --tokenizer-not-use-fast \ + --untie-embeddings-and-output-weights \ + --no-gradient-accumulation-fusion \ + --attention-softmax-in-fp32 \ + --seed 42 \ + --load ${CHECKPOINT} \ + --exit-on-missing-checkpoint \ + --no-load-rng \ + --no-load-optim \ + | tee logs/evaluation_mcore_qwen25_math_7b_all.log diff --git a/examples/mcore/qwen25_math/evaluate_qwen25_math_7b_ptd_single.sh b/examples/mcore/qwen25_math/evaluate_qwen25_math_7b_ptd_single.sh new file mode 100644 index 000000000..714aed3ae --- /dev/null +++ b/examples/mcore/qwen25_math/evaluate_qwen25_math_7b_ptd_single.sh @@ -0,0 +1,76 @@ +#!/bin/bash +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +NPUS_PER_NODE=1 +WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES)) + +# please fill these path configurations +CHECKPOINT=/data/00805503/model_from_hf/qwen2.5_math_mcore/ +TOKENIZER_PATH=/data/00805503/model_from_hf/qwen2.5_math_hf/ +DATA_PATH=/efs_gy1/yangziliang/data/tasks/gpqa/ +# /efs_gy1/yangziliang/data/tasks/aime /efs_gy1/yangziliang/data/tasks/math500 /efs_gy1/yangziliang/data/tasks/gpqa +TASK="gpqa_eval" +#aime_eval math500_eval gpqa_eval + + +TP=1 +PP=1 +SEQ_LENGTH=4096 + +DISTRIBUTED_ARGS=" + --nproc_per_node $NPUS_PER_NODE 
\ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +# Different task needs different max_new_tokens value, please follow the instruction in readme. +torchrun $DISTRIBUTED_ARGS evaluation.py \ + --evaluation-batch-size 1 \ + --use-mcore-models \ + --task ${TASK} \ + --task-data-path $DATA_PATH \ + --no-chat-template \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --num-layers 28 \ + --max-length 4096 \ + --max-new-tokens 2048 \ + --hidden-size 3584 \ + --ffn-hidden-size 18944 \ + --num-attention-heads 28 \ + --max-position-embeddings ${SEQ_LENGTH} \ + --seq-length ${SEQ_LENGTH} \ + --disable-bias-linear \ + --add-qkv-bias \ + --group-query-attention \ + --num-query-groups 4 \ + --swiglu \ + --use-fused-swiglu \ + --normalization RMSNorm \ + --norm-epsilon 1e-6 \ + --use-fused-rmsnorm \ + --position-embedding-type rope \ + --rotary-base 10000 \ + --use-fused-rotary-pos-emb \ + --make-vocab-size-divisible-by 1 \ + --padded-vocab-size 152064 \ + --micro-batch-size 1 \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_PATH} \ + --tokenizer-not-use-fast \ + --untie-embeddings-and-output-weights \ + --no-gradient-accumulation-fusion \ + --attention-softmax-in-fp32 \ + --seed 42 \ + --load ${CHECKPOINT} \ + --exit-on-missing-checkpoint \ + --no-load-rng \ + --no-load-optim \ +# | tee logs/evaluation_mcore_qwen25_math_7b_${TASK}.log diff --git a/mindspeed_llm/tasks/evaluation/eval_impl/aime_eval.py b/mindspeed_llm/tasks/evaluation/eval_impl/aime_eval.py new file mode 100644 index 000000000..a68b80f08 --- /dev/null +++ b/mindspeed_llm/tasks/evaluation/eval_impl/aime_eval.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import json +import random +import re +import tqdm +import pandas as pd + +from torch import distributed as dist +from mindspeed_llm.tasks.evaluation.eval_api.dataset_eval import DatasetEval +from mindspeed_llm.tasks.evaluation.eval_api.chat import Chat +from mindspeed_llm.tasks.utils.error_utils import check_divisible_by_zero + +logger = logging.getLogger(__name__) + + +def base_prompt(example: dict) -> str: + prompt = f"What is the correct answer to this question: {example['Problem']}" + return prompt + + +def zero_shot_prompt(example: dict) -> str: + prompt = base_prompt(example) + prompt += f"\n\nPlease reason step by step before answering, and put your final answer within \\boxed{{}}." 
+ return prompt + + +def create_prompts(examples: list[dict], prompt_type: str = 'zero_shot') -> tuple[list[str], list[dict]]: + if prompt_type == 'zero_shot': + return [zero_shot_prompt(example) for example in examples], examples + else: + raise ValueError(f"Prompt type {prompt_type} not supported.") + + +def load_jsonl(file_path: str) -> list[dict]: + if not file_path: + return None + + lines = [] + with open(file_path, 'r', encoding='utf-8') as file: + for line_num, line in enumerate(file, start=1): + try: + lines.append(json.loads(line)) + except json.JSONDecodeError as e: + logger.warning(f'{line_num}, {e}') + return lines + + +def read_parquet_to_dict_list(filepath: str) -> list[dict]: + df = pd.read_parquet(filepath) + return df.to_dict(orient='records') + + +class AIMEEval(DatasetEval): + """ + NOTE: AIME2024 datesets from https://huggingface.co/datasets/Maxwell-Jia/AIME_2024, the table header is not the + same with opencompass('origin_prompt', 'gold_answer'). + """ + + def __init__(self, test_dir, eval_args, output_template=r"(?i)\\boxed{([0-9.]*)}"): + self.test_dir = test_dir + self.instruction_template = eval_args.instruction_template + self.batch_size = eval_args.evaluation_batch_size + self.output_template = output_template + self.rank = dist.get_rank() + self.file_pbar = None + self.task_pbar = None + + def eval(self, chat: Chat) -> (dict, pd.DataFrame): + answer_result = {} + score_datas = [] + total_acc_n = 0 + total_n = 0 + + if self.rank == 0: + self.file_pbar = tqdm.tqdm(total=len(os.listdir(self.test_dir)), desc="total datafiles") + for file in os.listdir(self.test_dir): + file_path = os.path.join(self.test_dir, file) + if not os.path.exists(file_path): + raise FileExistsError("The file ({}) does not exist !".format(file_path)) + + question_df = read_parquet_to_dict_list(file_path) + prompts, question_list = create_prompts(question_df) + + ans, acc_n = self.__score(chat, prompts, question_list) + answer_result.update(ans) + + if self.rank == 0: + total_n += len(question_list) + total_acc_n += acc_n + + if self.task_pbar is not None: + self.task_pbar.close() + + if self.file_pbar is not None: + self.file_pbar.update() + + if self.rank == 0: + logger.info(f"aime2024 acc = {total_acc_n}/{total_n}={check_divisible_by_zero(total_acc_n, total_n)}") + score_datas.append(["total", total_n, total_acc_n / total_n]) + score_df = pd.DataFrame(columns=['subject', 'question_n', 'acc'], data=score_datas) + return answer_result, score_df + + def top_k_eval(self, ) -> (dict, pd.DataFrame): + pass + + def __score(self, chat: Chat, prompts, question_list) -> tuple[dict, int]: + acc_n = 0 + answer_result = {} + instructions = [] + corrects = [] + if self.rank == 0: + self.task_pbar = tqdm.tqdm(total=len(question_list), desc="questions", leave=False) + idx = 0 + for question_id, (prompt, question) in enumerate(zip(prompts, question_list)): + if self.task_pbar is not None: + self.task_pbar.update() + instructions.append(prompt) + corrects.append(str(question["Answer"])) + if len(instructions) == self.batch_size or len(prompts) == idx + 1: + try: + chat_results, rank = chat.chat(instruction=instructions, history=[]) + if chat_results: + for index, chat_result in enumerate(chat_results): + if isinstance(chat_result, list): + chat_result = chat_result[0] + ans_correct_str = f"Correct answer: {corrects[index]}\nChosen answer: " \ + f"{chat_result}\n, rank: {self.rank} " + logger.info(ans_correct_str) + match = re.search(self.output_template, chat_result, flags=re.DOTALL + re.MULTILINE) + 
match_flag = False + extracted_answer = '' + if match: + extracted_answer = match.group(1) if match else None + if extracted_answer.strip() == corrects[index].strip(): + logger.info(f'the {question_id}\'th question pass, ai answer {extracted_answer}') + answer_result[question_id] = extracted_answer + acc_n += 1 + match_flag = True + + if not match_flag: + logger.info(f'the {question_id}\'th question not match, ai answer {extracted_answer}') + else: + logger.warning(f'the {question_id}\'th question does not get a chat result.') + except Exception as e: + logger.warning(f'{question_id} error occurs: {str(e)}.') + + instructions = [] + corrects = [] + + return answer_result, acc_n + diff --git a/mindspeed_llm/tasks/evaluation/eval_impl/gpqa_eval.py b/mindspeed_llm/tasks/evaluation/eval_impl/gpqa_eval.py new file mode 100644 index 000000000..f32ae0d8f --- /dev/null +++ b/mindspeed_llm/tasks/evaluation/eval_impl/gpqa_eval.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import json +import random +import re +from collections import namedtuple +import tqdm +import pandas as pd + + +from torch import distributed as dist +from mindspeed_llm.tasks.evaluation.eval_api.dataset_eval import DatasetEval +from mindspeed_llm.tasks.evaluation.eval_api.chat import Chat +from mindspeed_llm.tasks.utils.error_utils import check_divisible_by_zero + +logger = logging.getLogger(__name__) +Example = namedtuple('Example', ['question', 'choice1', 'choice2', 'choice3', 'choice4', 'correct_answer']) +ANSWER_LIST = ['A', 'B', 'C', 'D'] + + +def base_prompt(example: Example) -> str: + prompt = f"What is the correct answer to this question: {example.question}" + prompt += f"\n\nChoices:\n(A) {example.choice1}\n(B) {example.choice2}" \ + f"\n(C) {example.choice3}\n(D) {example.choice4}" + return prompt + + +def zero_shot_prompt(example: Example) -> str: + prompt = base_prompt(example) + prompt += f"\n\nThink step by step before answering. put your final answer within \\boxed{{LETTER}}. where LETTER is one of ABCD. 
" + return prompt + + +def create_prompts(examples: list[Example], prompt_type: str = 'zero_shot') -> tuple[list[str], list[Example]]: + if prompt_type == 'zero_shot': + return [zero_shot_prompt(example) for example in examples], examples + else: + raise ValueError(f"Prompt type {prompt_type} not supported.") + + +def shuffle_choices_and_create_example(row) -> Example: + list_choices = [row['Incorrect Answer 1'], row['Incorrect Answer 2'], row['Incorrect Answer 3'], + row['Correct Answer']] + random.shuffle(list_choices) + example = Example(row.Question, list_choices[0], list_choices[1], list_choices[2], list_choices[3], + ANSWER_LIST[list_choices.index(row['Correct Answer'])]) + return example + + +class GPQAEval(DatasetEval): + def __init__(self, test_dir, eval_args, output_template=r"(?i)\\boxed{([A-D]?)}"): + self.test_dir = test_dir + self.instruction_template = eval_args.instruction_template + self.batch_size = eval_args.evaluation_batch_size + self.output_template = output_template + self.rank = dist.get_rank() + self.file_pbar = None + self.task_pbar = None + + def eval(self, chat: Chat) -> (dict, pd.DataFrame): + answer_result = {} + score_datas = [] + total_acc_n = 0 + total_n = 0 + + if self.rank == 0: + self.file_pbar = tqdm.tqdm(total=len(os.listdir(self.test_dir)), desc="total_datafiles") + + for file in os.listdir(self.test_dir): + file_path = os.path.join(self.test_dir, file) + if not os.path.exists(file_path): + raise FileExistsError("The file ({}) does not exist !".format(file_path)) + + question_df = pd.read_csv(os.path.join(file_path), delimiter=',') + question_list = [shuffle_choices_and_create_example(row) for _, row in question_df.iterrows()] + prompts, question_list = create_prompts(question_list) + + ans, acc_n = self.__score(chat, prompts, question_list) + answer_result.update(ans) + + if self.rank == 0: + total_n += len(question_list) + total_acc_n += acc_n + + if self.task_pbar is not None: + self.task_pbar.close() + + if self.file_pbar is not None: + self.file_pbar.update() + + if self.rank == 0: + logger.info(f"gpqa acc = {total_acc_n}/{total_n}={check_divisible_by_zero(total_acc_n, total_n)}") + score_datas.append(["total", total_n, total_acc_n / total_n]) + score_df = pd.DataFrame(columns=['subject', 'question_n', 'acc'], data=score_datas) + return answer_result, score_df + + def top_k_eval(self, ) -> (dict, pd.DataFrame): + pass + + def __score(self, chat:Chat, prompts, question_list) -> int: + acc_n = 0 + answer_result = {} + instructions = [] + corrects = [] + if self.rank == 0: + self.task_pbar = tqdm.tqdm(total=len(question_list), desc="questions", leave=False) + idx = 0 + for question_id, (prompt, question) in enumerate(zip(prompts, question_list)): + if self.task_pbar is not None: + self.task_pbar.update() + instructions.append(prompt) + corrects.append(question.correct_answer) + + if len(instructions) == self.batch_size or len(prompts) == idx + 1: + try: + chat_results, rank = chat.chat(instruction=instructions, history=[]) + if chat_results: + for index, chat_result in enumerate(chat_results): + if isinstance(chat_result, list): + chat_result = chat_result[0] + ans_correct_str = f"Correct answer: {corrects[index]}\nChosen answer: " \ + f"{chat_result}\n, rank: {self.rank} " + logger.info(ans_correct_str) + match = re.search(self.output_template, chat_result, flags=re.DOTALL + re.MULTILINE) + match_flag = False + if match: + extracted_answer = match.group(1) if match else None + if extracted_answer.strip() == corrects[index].strip(): + 
logger.info(f'the {question_id}\'th question pass, ai answer: {extracted_answer}') + answer_result[question_id] = extracted_answer + acc_n += 1 + match_flag = True + + if not match_flag: + logger.info(f'{question_id} not match, corrects answer {corrects[index]}') + else: + logger.warning(f'{question_id} does not get a chat result.') + except Exception as e: + logger.warning(f'{question_id} error occurs: {str(e)}.') + + instructions = [] + corrects = [] + + idx += 1 + + return answer_result, acc_n + diff --git a/mindspeed_llm/tasks/evaluation/eval_impl/math500_eval.py b/mindspeed_llm/tasks/evaluation/eval_impl/math500_eval.py new file mode 100644 index 000000000..438afa04e --- /dev/null +++ b/mindspeed_llm/tasks/evaluation/eval_impl/math500_eval.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import json +import re +import tqdm +import pandas as pd + +from torch import distributed as dist +from mindspeed_llm.tasks.evaluation.eval_api.dataset_eval import DatasetEval +from mindspeed_llm.tasks.evaluation.eval_api.chat import Chat +from mindspeed_llm.tasks.utils.error_utils import check_divisible_by_zero + +logger = logging.getLogger(__name__) + + +def base_prompt(example: dict) -> str: + prompt = f"What is the correct answer to this question: {example['problem']}" + return prompt + + +def zero_shot_prompt(example: dict) -> str: + prompt = base_prompt(example) + prompt += f"\n\nPlease reason step by step before answering, and put your final answer within \\boxed{{}}." 
+ return prompt + + +def create_prompts(examples: list[dict], prompt_type: str = 'zero_shot') -> tuple[list[str], list[dict]]: + if prompt_type == 'zero_shot': + return [zero_shot_prompt(example) for example in examples], examples + else: + raise ValueError(f"Prompt type {prompt_type} not supported.") + + +def is_all_chars_in_another(s1, s2): + for char in s1: + if char not in s2: + return False + return True + + +def bottom_up_dp_lcs(str_a, str_b): + """ + longest common substring of str_a and str_b + """ + if len(str_a) == 0 or len(str_b) == 0: + return True + dp = [[0 for _ in range(len(str_b) + 1)] for _ in range(len(str_a) + 1)] + max_len = 0 + for i in range(1, len(str_a) + 1): + for j in range(1, len(str_b) + 1): + if str_a[i-1] == str_b[j-1]: + dp[i][j] = dp[i-1][j-1] + 1 + max_len = max([max_len, dp[i][j]]) + else: + dp[i][j] = 0 + + if max_len > 0.7*len(str_a) or max_len > 0.7*len(str_b): + return True + return False + + +def load_jsonl(file_path: str) -> list[dict]: + if not file_path: + return None + + lines = [] + with open(file_path, 'r', encoding='utf-8') as file: + for line_num, line in enumerate(file, start=1): + try: + lines.append(json.loads(line)) + except json.JSONDecodeError as e: + logger.warning(f'{line_num}, {e}') + return lines + + +class MATH500Eval(DatasetEval): + """ + NOTE: datesets from https://huggingface.co/datasets/HuggingFaceH4/MATH-500/tree/main + """ + + def __init__(self, test_dir, eval_args, output_template=r"(?i)\\boxed{(.*)}"): + self.test_dir = test_dir + self.instruction_template = eval_args.instruction_template + self.output_template = output_template + self.batch_size = eval_args.evaluation_batch_size + self.rank = dist.get_rank() + self.file_pbar = None + self.task_pbar = None + + def eval(self, chat: Chat) -> (dict, pd.DataFrame): + answer_result = {} + score_datas = [] + total_acc_n = 0 + total_n = 0 + + if self.rank == 0: + self.file_pbar = tqdm.tqdm(total=len(os.listdir(self.test_dir)), desc="total datafiles") + for file in os.listdir(self.test_dir): + file_path = os.path.join(self.test_dir, file) + if not os.path.exists(file_path): + raise FileExistsError("The file ({}) does not exist !".format(file_path)) + + question_df = load_jsonl(file_path) + prompts, question_list = create_prompts(question_df) + + ans, acc_n = self.__score(chat, prompts, question_list) + if ans: + answer_result.update(ans) + + if self.rank == 0: + total_n += len(question_list) + total_acc_n += acc_n + + if self.task_pbar is not None: + self.task_pbar.close() + + if self.file_pbar is not None: + self.file_pbar.update() + + if self.rank == 0: + logger.info(f"match500 acc = {total_acc_n}/{total_n}={check_divisible_by_zero(total_acc_n, total_n)}") + score_datas.append(["total", total_n, total_acc_n / total_n]) + score_df = pd.DataFrame(columns=['subject', 'question_n', 'acc'], data=score_datas) + + return answer_result, score_df + + def top_k_eval(self, ) -> (dict, pd.DataFrame): + pass + + def __score(self, chat: Chat, prompts, question_list) -> tuple[dict, int]: + answer_result = {} + acc_n = 0 + instructions = [] + corrects = [] + if self.rank == 0: + self.task_pbar = tqdm.tqdm(total=len(question_list), desc="questions", leave=False) + idx = 0 + for question_id, (prompt, question) in enumerate(zip(prompts, question_list)): + if self.task_pbar is not None: + self.task_pbar.update() + instructions.append(prompt) + corrects.append(question["answer"]) + if len(instructions) == self.batch_size or len(prompts) == idx + 1: + try: + chat_results, rank = 
chat.chat(instructions, history=[])
+                    if chat_results:
+                        for index, chat_result in enumerate(chat_results):
+                            if isinstance(chat_result, list):
+                                chat_result = chat_result[0]
+                            ans_correct_str = f"Correct answer: {corrects[index]}\nChosen answer: " \
+                                              f"{chat_result}\n, rank: {self.rank} "
+                            logger.info(ans_correct_str)
+                            match = re.search(self.output_template, chat_result, flags=re.DOTALL+re.MULTILINE)
+
+                            match_flag = False
+                            extracted_answer = ''
+                            if match:
+                                extracted_answer = match.group(1) if match else None
+                                if bottom_up_dp_lcs(corrects[index], extracted_answer.strip()):
+                                    logger.info(f'the {question_id}\'th question pass, ai answer {extracted_answer}')
+                                    answer_result[question_id] = extracted_answer
+                                    acc_n += 1
+                                    match_flag = True
+
+                            if not match_flag:
+                                logger.info(f'the {question_id}\'th question not pass, ai answer {extracted_answer}.')
+                    else:
+                        logger.warning(f'the {question_id}\'th question does not get a chat result.')
+                except Exception as e:
+                    logger.error(f'error occurs: {str(e)}')
+
+                instructions = []
+                corrects = []
+
+        return answer_result, acc_n
--
Gitee
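
Note on the scoring path shared by the three new evaluators: the model is prompted zero-shot to reason step by step and put its final answer in \boxed{...}; the answer is then pulled out with the per-task output_template regex and compared to the reference (string equality for AIME and GPQA, a longest-common-substring heuristic for MATH-500). The snippet below is a minimal standalone sketch of that step for local sanity checks; the regex templates and the 0.7 LCS threshold come from the patch, while the function names (extract_boxed, lcs_close_enough) and the sample strings are illustrative only.

import re

# Answer-extraction templates as introduced by the patch:
# AIME keeps digits and dots, GPQA keeps a single letter, MATH-500 keeps anything inside \boxed{...}.
AIME_TEMPLATE = r"(?i)\\boxed{([0-9.]*)}"
GPQA_TEMPLATE = r"(?i)\\boxed{([A-D]?)}"
MATH500_TEMPLATE = r"(?i)\\boxed{(.*)}"


def extract_boxed(completion, template):
    """Pull the final answer out of \\boxed{...}; returns None when the model produced no box."""
    match = re.search(template, completion, flags=re.DOTALL | re.MULTILINE)
    return match.group(1) if match else None


def lcs_close_enough(str_a, str_b):
    """MATH-500 style fuzzy match: the longest common substring must cover >70% of either string."""
    if len(str_a) == 0 or len(str_b) == 0:
        return True
    dp = [[0] * (len(str_b) + 1) for _ in range(len(str_a) + 1)]
    max_len = 0
    for i in range(1, len(str_a) + 1):
        for j in range(1, len(str_b) + 1):
            if str_a[i - 1] == str_b[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
                max_len = max(max_len, dp[i][j])
    return max_len > 0.7 * len(str_a) or max_len > 0.7 * len(str_b)


if __name__ == "__main__":
    completion = "Adding the cases gives \\boxed{073} as the final answer."
    print(extract_boxed(completion, AIME_TEMPLATE))           # 073
    print(lcs_close_enough("\\frac{3}{4}", "\\dfrac{3}{4}"))  # True, substring overlap exceeds 70%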
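
For GPQA, each CSV row carries one correct and three incorrect answers; the evaluator shuffles the four candidates into positions A-D per question and records the letter that ends up holding the correct answer, which is what the \boxed{LETTER} output is scored against. A small self-contained illustration of that mapping follows; the helper name shuffle_row and the sample row are hypothetical, whereas the patch's shuffle_choices_and_create_example reads real rows from the GPQA CSV.

import random
from collections import namedtuple

Example = namedtuple('Example', ['question', 'choice1', 'choice2', 'choice3', 'choice4', 'correct_answer'])
ANSWER_LIST = ['A', 'B', 'C', 'D']


def shuffle_row(row):
    """Shuffle the four candidate answers and remember which letter holds the correct one."""
    choices = [row['Incorrect Answer 1'], row['Incorrect Answer 2'],
               row['Incorrect Answer 3'], row['Correct Answer']]
    random.shuffle(choices)
    return Example(row['Question'], *choices, ANSWER_LIST[choices.index(row['Correct Answer'])])


if __name__ == "__main__":
    row = {  # hypothetical row; real values come from the GPQA CSV
        'Question': 'Which particle mediates the electromagnetic force?',
        'Correct Answer': 'Photon',
        'Incorrect Answer 1': 'Gluon',
        'Incorrect Answer 2': 'W boson',
        'Incorrect Answer 3': 'Graviton',
    }
    example = shuffle_row(row)
    print(example.correct_answer)  # the letter (A-D) where 'Photon' landed after shuffling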