From 5ef0e1c6e72ef7f8f19f75dd8745b367e135e08c Mon Sep 17 00:00:00 2001 From: BernardLee Date: Sat, 1 Nov 2025 15:17:10 +0800 Subject: [PATCH] fix some bug for datasets --- .../benchmark/datasets/agieval/agieval.py | 7 ++-- .../datasets/agieval/dataset_loader.py | 39 ++++++++++++------- .../benchmark/datasets/agieval/evaluation.py | 2 +- .../datasets/agieval/post_process.py | 2 +- ais_bench/benchmark/datasets/bfcl/bfcl.py | 25 +++++++++++- 5 files changed, 54 insertions(+), 21 deletions(-) diff --git a/ais_bench/benchmark/datasets/agieval/agieval.py b/ais_bench/benchmark/datasets/agieval/agieval.py index 58ecfd5445..f08e37c03f 100644 --- a/ais_bench/benchmark/datasets/agieval/agieval.py +++ b/ais_bench/benchmark/datasets/agieval/agieval.py @@ -18,10 +18,10 @@ class AGIEvalDataset(BaseDataset): @staticmethod def load(path: str, name: str, setting_name: str): + # 先验证setting_name,再获取路径 + assert setting_name == 'zero-shot', 'only support zero-shot setting' path = get_data_path(path) from .dataset_loader import load_dataset, load_dataset_as_result_schema - - assert setting_name in 'zero-shot', 'only support zero-shot setting' dataset_wo_label = load_dataset(name, setting_name, path) dataset_with_label = load_dataset_as_result_schema(name, path) dataset = [] @@ -40,8 +40,9 @@ class AGIEvalDataset_v2(BaseDataset): @staticmethod def load(path: str, name: str, setting_name: str): + # 先验证setting_name,再获取路径 + assert setting_name == 'zero-shot', 'only support zero-shot setting' path = get_data_path(path) - assert setting_name in 'zero-shot', 'only support zero-shot setting' if environ.get('DATASET_SOURCE') == 'ModelScope': from modelscope import MsDataset diff --git a/ais_bench/benchmark/datasets/agieval/dataset_loader.py b/ais_bench/benchmark/datasets/agieval/dataset_loader.py index 75d90599b5..a0b94f6f3e 100644 --- a/ais_bench/benchmark/datasets/agieval/dataset_loader.py +++ b/ais_bench/benchmark/datasets/agieval/dataset_loader.py @@ -30,34 +30,42 @@ math_output_datasets = ['gaokao-mathcloze', 'math'] def convert_zero_shot(line, dataset_name): try: - passage = line['passage'] if line['passage'] is not None else '' + # 使用get方法安全访问字典键,避免KeyError + passage = line.get('passage', '') if line.get('passage') is not None else '' if dataset_name in english_qa_datasets: option_string = 'ABCDEFG' - count = len(line['options']) + options = line.get('options', []) + count = len(options) if count == 1: count = 5 - return passage + 'Q: ' + line['question'] + ' ' \ - + 'Answer Choices: ' + ' '.join(line['options']) + '\n' + \ + return passage + 'Q: ' + line.get('question', '') + ' ' \ + + 'Answer Choices: ' + ' '.join(options) + '\n' + \ 'A: Among A through {}, the answer is'.format(option_string[count - 1]) elif dataset_name in chinese_qa_datasets: option_string = 'ABCDEFG' - count = len(line['options']) + options = line.get('options', []) + count = len(options) if count == 1: count = 4 - return passage + '问题:' + line['question'] + ' ' \ - + '选项:' + ' '.join(line['options']) + '\n' + \ + return passage + '问题:' + line.get('question', '') + ' ' \ + + '选项:' + ' '.join(options) + '\n' + \ '答案:从A到{}, 我们应选择'.format(option_string[count - 1]) elif dataset_name in english_cloze_datasets: - return passage + 'Q: ' + line['question'] + '\n' \ + return passage + 'Q: ' + line.get('question', '') + '\n' \ 'A: The answer is' elif dataset_name in chinese_cloze_datasets: - return passage + '问题:' + line['question'] + '\n' \ + return passage + '问题:' + line.get('question', '') + '\n' \ '答案:' - except NameError: - print('Dataset not defined.') + else: + # 如果数据集不匹配任何已知类型,返回一个默认值 + return passage + line.get('question', '') + except Exception as e: + print(f'Error in convert_zero_shot: {e}') + # 在出现异常时返回一个默认值而不是None + return line.get('passage', '') + line.get('question', '') prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n' @@ -240,9 +248,10 @@ def concat_prompt_chat_mode(demos, def convert_few_shot(line, dataset_name, demo, n_shot, chat_mode=False): - passage = line['passage'] if line['passage'] is not None else '' + # 使用get方法安全访问字典键,避免KeyError + passage = line.get('passage', '') if line.get('passage') is not None else '' question = line['question'] - options = line['options'] if line['options'] is not None else '' + options = line.get('options', []) if line.get('options') is not None else [] if dataset_name in english_qa_datasets: question_input = 'Problem {}. '.format(n_shot + 1) + passage + ' ' + question + '\n' \ @@ -377,11 +386,13 @@ def load_dataset_as_result_schema(dataset_name, parent_path): processed = [] for i, line in enumerate(loaded_jsonl): problem_input = convert_zero_shot(line, dataset_name) + # 安全地访问'label'和'answer'键,避免KeyError + label = line.get('label') if line.get('label') else line.get('answer') processed.append( ResultsForHumanSchema( index=i, problem_input=problem_input, - label=line['label'] if line['label'] else line['answer'], + label=label, )) return processed diff --git a/ais_bench/benchmark/datasets/agieval/evaluation.py b/ais_bench/benchmark/datasets/agieval/evaluation.py index c5a9916a11..cb2dc3b83b 100644 --- a/ais_bench/benchmark/datasets/agieval/evaluation.py +++ b/ais_bench/benchmark/datasets/agieval/evaluation.py @@ -9,7 +9,7 @@ def convert_to_set(item): if isinstance(item, str): return {item} if item is None: - return {} + return set() raise ValueError("Input can't parse:", item) diff --git a/ais_bench/benchmark/datasets/agieval/post_process.py b/ais_bench/benchmark/datasets/agieval/post_process.py index ed3eb463ae..3fa28e1438 100644 --- a/ais_bench/benchmark/datasets/agieval/post_process.py +++ b/ais_bench/benchmark/datasets/agieval/post_process.py @@ -80,7 +80,7 @@ def find_first_capital_letter(answer): def extract_answer_in_bracket(answer, prefix='【', suffix='】'): - if prefix not in answer and suffix not in answer: + if prefix not in answer or suffix not in answer: # print("doesn't found special tokens in:", answer) return '' s = answer.index(prefix) + len(prefix) diff --git a/ais_bench/benchmark/datasets/bfcl/bfcl.py b/ais_bench/benchmark/datasets/bfcl/bfcl.py index 29b04c1317..e8896937c7 100644 --- a/ais_bench/benchmark/datasets/bfcl/bfcl.py +++ b/ais_bench/benchmark/datasets/bfcl/bfcl.py @@ -6,6 +6,16 @@ from os import environ from datasets import Dataset from .bfcl_dependency import * + +# Fallback implementations for BFCL dependency functions +# in case the BFCL package is not installed +def is_java(category): + """Check if the category is Java-related.""" + return 'java' in category.lower() + +def is_js(category): + """Check if the category is JavaScript-related.""" + return 'javascript' in category.lower() or 'js' in category.lower() from ..base import BaseDataset from ais_bench.benchmark.openicl.icl_evaluator import BaseEvaluator @@ -337,22 +347,33 @@ class BFCLMultiTurnEvaluator(BFCLEvaluator): where the model needs to maintain context across multiple interactions. """ - def decode_execute(self, result, is_fc_model=True): + def decode_execute(self, result, is_fc_model=None): """ Decode execution result from model output. Args: result: Model output to decode - is_fc_model: Whether the model supports function calling format + is_fc_model: Whether the model supports function calling format. If None, use self.is_fc_model Returns: List of executable function calls """ + # 如果没有指定is_fc_model,使用实例的is_fc_model属性 + if is_fc_model is None: + is_fc_model = self.is_fc_model + if is_fc_model: if isinstance(result, str): result = json.loads(result) return convert_to_function_call(result) else: + # 对于Prompting模型,也需要检查并处理字符串类型的输入 + if isinstance(result, str): + try: + result = json.loads(result) + except json.JSONDecodeError: + # 如果不是有效的JSON字符串,就保持原样 + pass return default_decode_execute_prompting(result) def score(self, predictions, references, test_set): -- Gitee