diff --git a/fastSum/Dataloader/summarizationLoader.py b/fastSum/Dataloader/summarizationLoader.py index ccaa16f7da202f2126f8dcb71a6931f46c7c31e7..74d0763ecef54444dae036dc1d3affe5f8ddeb75 100644 --- a/fastSum/Dataloader/summarizationLoader.py +++ b/fastSum/Dataloader/summarizationLoader.py @@ -1,4 +1,4 @@ -from typing import Union, Dict +from typing import Union, Dict, Optional import os import random @@ -29,12 +29,13 @@ class SumLoader(JsonLoader): 所有摘要数据集loader的父类 """ - def __init__(self): - fields = { - 'text': 'text', - 'summary': 'summary', - 'label': Const.TARGET - } + def __init__(self, fields: Optional[Dict[str, str]] = None): + if fields is None: + fields = { + 'text': 'text', + 'summary': 'summary', + 'label': Const.TARGET + } super(SumLoader, self).__init__(fields=fields) def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: @@ -44,6 +45,11 @@ class SumLoader(JsonLoader): default_cache_path = get_cache_path() url = _get_dataset_url(dataset_name, DATASET_DIR) output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') + # https://gitee.com/fastnlp/fastNLP/blob/7b4e099c5267efb6a4a88b9d789a0940be05bb56/fastNLP/io/file_utils.py#L201 + # 如果只有一个文件, get_filepath 返回 filepath + filename + # os.path.dirname 反向处理 + if os.path.isfile(output_dir): + output_dir = os.path.dirname(output_dir) return output_dir @@ -356,7 +362,13 @@ class AMILoader(SumLoader): """ def __init__(self, valid_ratio=0.05, test_ratio=0.05): - super(AMILoader, self).__init__() + # AMI 没有 label + fields = { + 'text': 'text', + 'summary': 'summary', + } + super(AMILoader, self).__init__(fields) + self.valid_ratio = valid_ratio self.test_ratio = test_ratio @@ -398,7 +410,12 @@ class ICSILoader(SumLoader): """ def __init__(self, valid_ratio=0.05, test_ratio=0.05): - super(ICSILoader, self).__init__() + # ICSI 没有 label + fields = { + 'text': 'text', + 'summary': 'summary', + } + super(ICSILoader, self).__init__(fields) self.valid_ratio = valid_ratio self.test_ratio = test_ratio @@ -431,7 +448,7 @@ class ICSILoader(SumLoader): return data_bundle -def _split_set(dataset_name, data_dir, split_name1="dev", split_name2="train", ratio=0.0, suffix='jsonl'): +def _split_set(dataset_name, data_dir, split_name1="dev", split_name2="train", ratio=0.0, suffix='jsonl', keep_orig: bool = True): if ratio == 0: os.renames(os.path.join(data_dir, f'{dataset_name}.{suffix}'), os.path.join(data_dir, f'{split_name2}.{suffix}')) @@ -449,7 +466,10 @@ def _split_set(dataset_name, data_dir, split_name1="dev", split_name2="train", r f2.write(line) else: f1.write(line) - os.remove(os.path.join(data_dir, f'{dataset_name}.{suffix}')) + if keep_orig: + assert split_name1 != dataset_name and split_name2 != dataset_name + else: + os.remove(os.path.join(data_dir, f'{dataset_name}.{suffix}')) os.renames(os.path.join(data_dir, f'middle_file.{suffix}'), os.path.join(data_dir, f'{split_name2}.{suffix}')) finally: