From ce87dcd82bedd5ac49f2eedc9cf5286ea4ac0719 Mon Sep 17 00:00:00 2001 From: yjqiang <8900942+yjqiang@users.noreply.github.com> Date: Wed, 3 Mar 2021 18:58:53 +0800 Subject: [PATCH] change the download() method and fix paths type in summarizationLoader.py --- fastSum/Dataloader/summarizationLoader.py | 77 +++++++++++++++-------- 1 file changed, 51 insertions(+), 26 deletions(-) diff --git a/fastSum/Dataloader/summarizationLoader.py b/fastSum/Dataloader/summarizationLoader.py index 74d0763..65d6976 100644 --- a/fastSum/Dataloader/summarizationLoader.py +++ b/fastSum/Dataloader/summarizationLoader.py @@ -1,6 +1,7 @@ -from typing import Union, Dict, Optional +from typing import Dict, Optional import os import random +from pathlib import Path from fastNLP.io.loader import JsonLoader from fastNLP.io.data_bundle import DataBundle @@ -29,6 +30,8 @@ class SumLoader(JsonLoader): 所有摘要数据集loader的父类 """ + DATASET_NAME = None # 对应 DATASET_DIR 中的 key + def __init__(self, fields: Optional[Dict[str, str]] = None): if fields is None: fields = { @@ -38,12 +41,12 @@ class SumLoader(JsonLoader): } super(SumLoader, self).__init__(fields=fields) - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: pass - def download(self, dataset_name): + def download(self): default_cache_path = get_cache_path() - url = _get_dataset_url(dataset_name, DATASET_DIR) + url = _get_dataset_url(self.DATASET_NAME, DATASET_DIR) output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') # https://gitee.com/fastnlp/fastNLP/blob/7b4e099c5267efb6a4a88b9d789a0940be05bb56/fastNLP/io/file_utils.py#L201 # 如果只有一个文件, get_filepath 返回 filepath + filename @@ -61,12 +64,14 @@ class CNNDMLoader(SumLoader): https://www.aclweb.org/anthology/K16-1028/ """ + DATASET_NAME = "cnndm" + def __init__(self): super(CNNDMLoader, self).__init__() - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("cnndm") + paths = self.download() _paths = {} if paths: @@ -93,12 +98,14 @@ class ArxivLoader(SumLoader): https://arxiv.org/abs/1804.05685 """ + DATASET_NAME = "arxiv" + def __init__(self): super(ArxivLoader, self).__init__() - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("arxiv") + paths = self.download() _paths = {} if paths: @@ -125,12 +132,14 @@ class BillSumLoader(SumLoader): https://arxiv.org/abs/1910.00523 """ + DATASET_NAME = "billsum" + def __init__(self): super(BillSumLoader, self).__init__() - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("billsum") + paths = self.download() _paths = {} if paths: @@ -157,12 +166,14 @@ class MultiNewsLoader(SumLoader): https://arxiv.org/abs/1906.01749 """ + DATASET_NAME = "multi-news" + def __init__(self): super(MultiNewsLoader, self).__init__() - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("multi-news") + paths = self.download() _paths = {} if paths: @@ -189,12 +200,14 @@ class PubmedLoader(SumLoader): https://arxiv.org/abs/1804.05685 """ + DATASET_NAME = "pubmed" + def __init__(self): super(PubmedLoader, self).__init__() - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("pubmed") + paths = self.download() _paths = {} if paths: @@ -221,12 +234,14 @@ class SAMSumLoader(SumLoader): https://arxiv.org/abs/1911.12237 """ + DATASET_NAME = "samsum" + def __init__(self): super(SAMSumLoader, self).__init__() - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("samsum") + paths = self.download() _paths = {} if paths: @@ -253,12 +268,14 @@ class WikiHowLoader(SumLoader): https://arxiv.org/abs/1810.09305 """ + DATASET_NAME = "wikihow" + def __init__(self): super(WikiHowLoader, self).__init__() - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("wikihow") + paths = self.download() _paths = {} if paths: @@ -285,12 +302,14 @@ class XsumLoader(SumLoader): https://arxiv.org/abs/1808.08745 """ + DATASET_NAME = "xsum" + def __init__(self): super(XsumLoader, self).__init__() - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("xsum") + paths = self.download() _paths = {} if paths: @@ -317,6 +336,8 @@ class RedditTIFULoader(SumLoader): https://arxiv.org/abs/1811.00783 """ + DATASET_NAME = "reddit tifu" + def __init__(self, tag, valid_ratio=0.05, test_ratio=0.05): super(RedditTIFULoader, self).__init__() self.valid_ratio = valid_ratio @@ -324,9 +345,9 @@ class RedditTIFULoader(SumLoader): assert tag in ["long", "short"], "tag not valid (neither long nor short)!" self.tag = tag - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("reddit tifu") + paths = self.download() _paths = {} if paths: @@ -361,6 +382,8 @@ class AMILoader(SumLoader): http://groups.inf.ed.ac.uk/ami/download/ """ + DATASET_NAME = "ami" + def __init__(self, valid_ratio=0.05, test_ratio=0.05): # AMI 没有 label fields = { @@ -372,9 +395,9 @@ class AMILoader(SumLoader): self.valid_ratio = valid_ratio self.test_ratio = test_ratio - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("ami") + paths = self.download() _paths = {} if paths: @@ -409,6 +432,8 @@ class ICSILoader(SumLoader): http://groups.inf.ed.ac.uk/ami/icsi/ """ + DATASET_NAME = "icsi" + def __init__(self, valid_ratio=0.05, test_ratio=0.05): # ICSI 没有 label fields = { @@ -419,9 +444,9 @@ class ICSILoader(SumLoader): self.valid_ratio = valid_ratio self.test_ratio = test_ratio - def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + def load(self, paths: Optional[Path] = None) -> DataBundle: if paths is None: - paths = self.download("icsi") + paths = self.download() _paths = {} if paths: -- Gitee