From 768aec498666c30464cc1a29a58f8ce6f078a8f7 Mon Sep 17 00:00:00 2001 From: Zhangyue Yin <40982202+yinzhangyue@users.noreply.github.com> Date: Thu, 4 Mar 2021 21:02:13 +0800 Subject: [PATCH 1/4] DataLoader Readme DataLoader Readme --- fastSum/Dataloader/Readme_en.md | 6 ++++++ fastSum/Dataloader/Readme_zh.md | 6 ++++++ fastSum/Dataloader/example.py | 6 ++++++ 3 files changed, 18 insertions(+) create mode 100644 fastSum/Dataloader/Readme_en.md create mode 100644 fastSum/Dataloader/Readme_zh.md create mode 100644 fastSum/Dataloader/example.py diff --git a/fastSum/Dataloader/Readme_en.md b/fastSum/Dataloader/Readme_en.md new file mode 100644 index 0000000..163bb03 --- /dev/null +++ b/fastSum/Dataloader/Readme_en.md @@ -0,0 +1,6 @@ +# Tips + +1. please install the latest FastNLP: pip install git+https://gitee.com/fastnlp/fastNLP@dev +2. Specify FASTNLP_CACHE_DIR in the system environment variable, which will be the data set download location. +3. example.py is a simple example. + diff --git a/fastSum/Dataloader/Readme_zh.md b/fastSum/Dataloader/Readme_zh.md new file mode 100644 index 0000000..88722a5 --- /dev/null +++ b/fastSum/Dataloader/Readme_zh.md @@ -0,0 +1,6 @@ +# 使用提醒 + +1. 使用前请安装最新的fastNLP,安装方式:pip install git+https://gitee.com/fastnlp/fastNLP@dev +2. 在系统环境变量中指定FASTNLP_CACHE_DIR的位置,为数据集下载位置 +3. 使用方法可参照example.py + diff --git a/fastSum/Dataloader/example.py b/fastSum/Dataloader/example.py new file mode 100644 index 0000000..656f5ac --- /dev/null +++ b/fastSum/Dataloader/example.py @@ -0,0 +1,6 @@ +from summarizationLoader import ArxivLoader + +if __name__ == '__main__': + ArxivLoader().download() + data = ArxivLoader().load() + print(data) \ No newline at end of file -- Gitee From 0f14cf0df430fd49d03a8b8a2f00ea79c78433df Mon Sep 17 00:00:00 2001 From: Zhangyue Yin <40982202+yinzhangyue@users.noreply.github.com> Date: Thu, 4 Mar 2021 23:04:23 +0800 Subject: [PATCH 2/4] DataLoader Update DataLoader Update --- .../summarizationLoader.cpython-37.pyc | Bin 0 -> 17107 bytes fastSum/Dataloader/example.py | 7 +- fastSum/Dataloader/summarizationLoader.py | 162 ++++++++++++------ 3 files changed, 120 insertions(+), 49 deletions(-) create mode 100644 fastSum/Dataloader/__pycache__/summarizationLoader.cpython-37.pyc diff --git a/fastSum/Dataloader/__pycache__/summarizationLoader.cpython-37.pyc b/fastSum/Dataloader/__pycache__/summarizationLoader.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..558ec94158cfa67c94c03f44931251862ec6f9c4 GIT binary patch literal 17107 zcmeI4?Q;}Ie#hr&cV{0YEdm4x3=77_6)N0QDtCvkP^IeXT&Y|hURB z??-rmtwVU7(U0(Igx9kT2yYPKH3$!~jR zBD__E*CV`*J&N$7BD?|N?d&mx9~0q0gm~Vx27vV?P z0b?^8zAdwTZ2yf&ehX91O4%nrqPET|t~O?t99RF+v}0B)MQc1OyV~`0?POKA_oCqxO2yJ8qcDxz+;o8=7_PPPTjQ3S6sb%k z*JGonPvz?7bUi-Gub8h1Z}O;VS!ZiyFaFs0_}FRT)lVHe>s{P?x@I}%6XV8=Em9KC z*Cxsa6aKz424kj^Id|&Bi(cZ{(bGY0$@Auvd9pg=<@b{96_QQ5YO!p(ks{T>jZL6@ zB)O4NrNYXtW|nNTo+yhlD=5>APK)yDN$!>9n3FX(Y8T6NS8N7lT&m8v+7$##)ULr1 z@jdpJ;W4DV%%r@+WXP@pxlI}oBWlEq*rc*S+5joOEx#pA%1mRC8$EfQrH~%qAQ=f3 zy{)hq)6siLlG7VfKJ}KQN%K z{pH=8Km6#sZ(5=+?|tXSy|;gK@2}n$7gC2uCr!g*c1iK8OTQXU6{o!iP6HL6()j?B znc_&7WmHw2l^w-V8`5>tG&VihM0 zD=YK02xTKYMv>8Tqi3Eye$I{9wP}O%H58{83TDN03I!5U+ps3RJ7hOrD3q+CZ5Imf zNguy(?CA?H+6K2TTs158OBcqfrCQmjIQE5SP3L57;=*LncF;cIFG;c24i~)Yne`%x zzSn@c>8pGrO5cXl7JQO*MUfRbq15}sX3n)#EkaTc3E_rQk>LeZb+BW zG1rw@#nBoHx(1zgQ=L_v$mQsrSVNtY`H1lA*CgvL&t;nak#Qng)8z#(ed zlqHHXxyVKn4dM1PW3q_h)(R4p*paB#oaq{PV)FS7rL*cxg<6DU-cG}jw~Zr{E314Q5bZi^ESnF+8+yWUsiK z_o6nTV3UATThnRx!*L`M>g&oKWlr)LgDzD0Akz3ja;PQIyfmnSRu;`Jnwk<%;0y-T znwEA5NAftvWCj=Ut#HCoOyi4sS5)Ok@c1B1zb)O7<}lh!X%-`lf8CKUD|gg}jInh? z_%-21X!Ok~bFzKXi85UIUFlurx_k|z6xaKy4TY(P@rZtmHm4va(h?IDF|or^Lz$H; zy-tksVY-)-5CNJK^oD##pOeO<>lo8cyrHt>O?g&pXz2XZczvBWuR6SFR&s*DxyzW| zt@^GG2^rJQRH{y9vWhuBV^%V5MBE(Y^jgs}S!+(N%8b|Q1Acaf?X)FgO0b>!V32Du zBks>&U>hZ;%CF|a{-?&KhF^`-rY<$5hV=593wZchg&Mwr25rdSP+rBOiQz@-kpjsr z(52*t_zv8}cf;{zqKVg!6NeerJEdxQ`k4!qKaJ<8RNvkz69OuSt!k-g+0W#fH=r2Z zXQZ7?h{4XYK~ZkRw24xLZdbcpH7o9_6JXtV)j2^ec%1Vpcl!|>9Sfe>i;I|!QMLs| zFdxzb3rTdNV#Q?hXDG)c611`wtI#58c2LQF#L z$6r+GRR-X%r{_N~Q8CGfaXa5f&VF(ZkR#wngUBP`Jmft^35zq1hUQ_MHcb_Jlo*#% z#egj#B0N4r!ihDdPq?q506`ZK*)76}6{H8Z348DG$k2h@(34M%Jn3MZHHwCo^{ zZfNMqefvE;_0aRmaOj|4ntJJ6dfD|Xu38#TwzCqR7C8T%3ib|msRd9x@s{Vx! z(FJ(X*n5-e)&?(%?Gn8D#OpVD`{v{8I?56V2@{qJU_(#khK3IwI3!AMeMk3z96O08 zS(;-%rq=d3wmy7juQ47Tl5G@kMCpa-A|TuI35gvz){9>UTa&Q|KTCmdhh$qBG}4yO6kC?Ag2z%S{G!uVrlhXfi!?t zkF`m|N zdZS@Qyv>SH?LIxD7Sz8BjG{sIG0t{IQEZnO)#K|`_>g)7B?(ga(aS`s5qh0#$~cTDO4s&&B)AXV$fLhqgX`XH6Q z9JVE|n78gn0)e0xSm!-eqZLG|Jyg9N5be&N1lzWK{RJ;VQOe3H#BnafR8EquKj?RdJeZPjj1mMQVC40Za(|N!>9pL#9IVV zI~R}dKqoJk4vVO@UREoXOe~NAkm=DD8L&uUYEj*W<@+3XhU&HrqP!TENK2tbPoZyK z-Mhde8dS#@@Q7j;=MjC5ov(<+5~1d19I9B!P!%f%%}wYQLvypKij{KHp3W+%tomH~ zePk-Q)S}J0TzJA>Z)oV?i1$&jO~>`#cfL)>d!H`_N*eU5`ln6mA??ap&^MF{&?tmU zN}=*p)ZXnxn9zA@K#8jm6?%{=)U8OoKsqb`(7lpc!7;l+f6~lBc89YPR zAf9K$W!bISn5Xu28bu2dx=C?Ifs+=eo1nFV@9!P8p%_ZKot~bJiXyU}!XZCLw?pe( z;%DG2t}|wnOppeLB=2efmU&_0xP$>I~6JxR;hkKw{rU}I|J&e`fMLdhJD6Wd*Ostbw>}EukTEhC2 z#{b&U8L6T{iHxplrf&4F4^`8WLJ=)`SmeOy-hU0#QX(xO(h?%=*I`mScYNa?yzPv?qO+hVc4b$-kOp4Su2OWbJLWKdKS zS|89COWM;Gw2VpPbGS9I&y5I?;^;C-=y`an(xkUP_jKpt$jDct6O!lqyQSrJKMEn= zBk842R+s>?3uJbwX`pwWcH4Zd5HWs6kvp_!y#|D~e{(?nrV1m!OO-!@?1LF1Du$qr znL&})sJ(FMFO5G^_oY8g%p0WTMrZgXW9_1k4rR}AWGL#dd8* ziI1K>HGe8WR5zv&==EOY)!MY3Gb@uhSYZ@Px!TlT7*y;Hw5)x#-Xxy5IfW1b>1dlm zNK2Cx)L~JeKy5YD^d;4F8LSl32~8?6^MN9(Hr_NPKqrP~1yD&6+Q4~aC#Hp_EC$FA zY46(+c*@v9b5STNk%r@+Q>P{1kyI#=D`^G5?+;t{en`ItIW10l{vU#ip0e=Ig};UJ ziz*Ap7?B)iAz-skdmo?)|$k z&7(YMO?im4e+tt)%IhJ?6N2PDEhLBWlpuLBBzX!XPlqI5H4f7s*Z_F|9Spq-pq*Ed z#$TikB?Ud`!F`$h=gE1A9JKKNMlk;)$(Rr3ZB5x)ipV`n5P6hZfM6gB&XP}p*KNh zYCe0fp&+r0r?_Q?FIlbm;O#>-mKVGiP~vX@-h0$#E#Re<_UdN4{y%xyXF^^fsy=!* z3hPRN;f3tgDDnRwd{+Uy0el~Ak*^g<-~aqf=-W5%A%WkQX{eIZ1^f~Q{&^mL6I{T7 z^ytM*eh(gTxDg?J=U`V8pE8VTSR9%cOEEu{7LTN$O!JrhE9BYG@IFD|M;z&n3?o!w zZ`F`!OWcTH`-^RH7^}9%&cSB9A)m#Y3f!ZG*C=h*$hTzsm!bp>o9|yTxuI#;vrkF< zb*4^9_8(M`3*KaiP4+FhzT1h0IknY+Hr&`|tABKBHrBj`>7tIxoPuq(mp(bGH`5z! zwF$2onT=l=YQ)B*mv_x3oY4conW*Bsg$ zi{X~?TGM3IpL#8u7#Ax&SI&I|A;8;A?76)+JoHYeX0gnKLFQ?^Sh;9qb`9+r&gNX1 z*CRD&@`(|$*)l4nDl;n=>nYm$!_JdzP;t!+j32Vd(v5LpkxfY{Tw$GxoaM%IBtwAbuDv( z(>^V~itMMypL;%MR>jA6SZ9-MUVfk@zST617Ml6~AhA^CMy^EvDZsPyUK=DJH|m5! z%`q*zX-X{Yh50VzAq-(XW4jh-)BgkY|1`OCUa;?vKmgwYcy|h+wSUGdTaVI{OwBb6u&#t+tW|u L?{hGn5P$yzppvMd literal 0 HcmV?d00001 diff --git a/fastSum/Dataloader/example.py b/fastSum/Dataloader/example.py index 656f5ac..7c0d190 100644 --- a/fastSum/Dataloader/example.py +++ b/fastSum/Dataloader/example.py @@ -1,6 +1,11 @@ -from summarizationLoader import ArxivLoader +from fastNLP.io.file_utils import get_cache_path +from summarizationLoader import ArxivLoader if __name__ == '__main__': + + # 请设置fastNLP默认cache的存放路径FASTNLP_CACHE_DIR, get_cache_path会获取设置下载的数据位置 + # 详细可参考: https://gitee.com/fastnlp/fastNLP/blob/7b4e099c5267efb6a4a88b9d789a0940be05bb56/fastNLP/io/file_utils.py#L228 + print(f'下载的数据位置: {get_cache_path()}') ArxivLoader().download() data = ArxivLoader().load() print(data) \ No newline at end of file diff --git a/fastSum/Dataloader/summarizationLoader.py b/fastSum/Dataloader/summarizationLoader.py index 65d6976..c813399 100644 --- a/fastSum/Dataloader/summarizationLoader.py +++ b/fastSum/Dataloader/summarizationLoader.py @@ -8,7 +8,6 @@ from fastNLP.io.data_bundle import DataBundle from fastNLP.core.const import Const from fastNLP.io.file_utils import get_cache_path, _get_dataset_url, cached_path - DATASET_DIR = { # Summarization 'ami': "AMI.zip", @@ -47,7 +46,9 @@ class SumLoader(JsonLoader): def download(self): default_cache_path = get_cache_path() url = _get_dataset_url(self.DATASET_NAME, DATASET_DIR) - output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') + output_dir = cached_path(url_or_filename=url, + cache_dir=default_cache_path, + name='dataset') # https://gitee.com/fastnlp/fastNLP/blob/7b4e099c5267efb6a4a88b9d789a0940be05bb56/fastNLP/io/file_utils.py#L201 # 如果只有一个文件, get_filepath 返回 filepath + filename # os.path.dirname 反向处理 @@ -76,9 +77,12 @@ class CNNDMLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'CNNDM.train.label.jsonl')): - raise FileNotFoundError(f"CNNDM.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, 'CNNDM.train.label.jsonl') + if not os.path.isfile( + os.path.join(paths, 'CNNDM.train.label.jsonl')): + raise FileNotFoundError( + f"CNNDM.train.label.jsonl is not found in {paths}") + _paths['train'] = os.path.join(paths, + 'CNNDM.train.label.jsonl') _paths['dev'] = os.path.join(paths, 'CNNDM.valid.label.jsonl') _paths['test'] = os.path.join(paths, 'CNNDM.test.label.jsonl') paths = _paths @@ -110,9 +114,12 @@ class ArxivLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'arxiv.train.label.jsonl')): - raise FileNotFoundError(f"arxiv.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, 'arxiv.train.label.jsonl') + if not os.path.isfile( + os.path.join(paths, 'arxiv.train.label.jsonl')): + raise FileNotFoundError( + f"arxiv.train.label.jsonl is not found in {paths}") + _paths['train'] = os.path.join(paths, + 'arxiv.train.label.jsonl') _paths['dev'] = os.path.join(paths, 'arxiv.valid.label.jsonl') _paths['test'] = os.path.join(paths, 'arxiv.test.label.jsonl') paths = _paths @@ -144,11 +151,17 @@ class BillSumLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'billsum_us.train.label.jsonl')): - raise FileNotFoundError(f"billsum_us.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, 'billsum_us.train.label.jsonl') - _paths['dev'] = os.path.join(paths, 'billsum_ca.valid.label.jsonl') - _paths['test'] = os.path.join(paths, 'billsum_us.test.label.jsonl') + if not os.path.isfile( + os.path.join(paths, 'billsum_us.train.label.jsonl')): + raise FileNotFoundError( + f"billsum_us.train.label.jsonl is not found in {paths}" + ) + _paths['train'] = os.path.join(paths, + 'billsum_us.train.label.jsonl') + _paths['dev'] = os.path.join(paths, + 'billsum_ca.valid.label.jsonl') + _paths['test'] = os.path.join(paths, + 'billsum_us.test.label.jsonl') paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -178,11 +191,16 @@ class MultiNewsLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'multinews.train.label.jsonl')): - raise FileNotFoundError(f"multinews.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, 'multinews.train.label.jsonl') - _paths['dev'] = os.path.join(paths, 'multinews.valid.label.jsonl') - _paths['test'] = os.path.join(paths, 'multinews.test.label.jsonl') + if not os.path.isfile( + os.path.join(paths, 'multinews.train.label.jsonl')): + raise FileNotFoundError( + f"multinews.train.label.jsonl is not found in {paths}") + _paths['train'] = os.path.join(paths, + 'multinews.train.label.jsonl') + _paths['dev'] = os.path.join(paths, + 'multinews.valid.label.jsonl') + _paths['test'] = os.path.join(paths, + 'multinews.test.label.jsonl') paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -212,9 +230,12 @@ class PubmedLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'pubmed.train.label.jsonl')): - raise FileNotFoundError(f"pubmed.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, 'pubmed.train.label.jsonl') + if not os.path.isfile( + os.path.join(paths, 'pubmed.train.label.jsonl')): + raise FileNotFoundError( + f"pubmed.train.label.jsonl is not found in {paths}") + _paths['train'] = os.path.join(paths, + 'pubmed.train.label.jsonl') _paths['dev'] = os.path.join(paths, 'pubmed.valid.label.jsonl') _paths['test'] = os.path.join(paths, 'pubmed.test.label.jsonl') paths = _paths @@ -246,9 +267,12 @@ class SAMSumLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'SAMSum.train.label.jsonl')): - raise FileNotFoundError(f"SAMSum.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, 'SAMSum.train.label.jsonl') + if not os.path.isfile( + os.path.join(paths, 'SAMSum.train.label.jsonl')): + raise FileNotFoundError( + f"SAMSum.train.label.jsonl is not found in {paths}") + _paths['train'] = os.path.join(paths, + 'SAMSum.train.label.jsonl') _paths['dev'] = os.path.join(paths, 'SAMSum.valid.label.jsonl') _paths['test'] = os.path.join(paths, 'SAMSum.test.label.jsonl') paths = _paths @@ -280,11 +304,15 @@ class WikiHowLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'wikihow.train.label.jsonl')): - raise FileNotFoundError(f"wikihow.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, 'wikihow.train.label.jsonl') + if not os.path.isfile( + os.path.join(paths, 'wikihow.train.label.jsonl')): + raise FileNotFoundError( + f"wikihow.train.label.jsonl is not found in {paths}") + _paths['train'] = os.path.join(paths, + 'wikihow.train.label.jsonl') _paths['dev'] = os.path.join(paths, 'wikihow.val.label.jsonl') - _paths['test'] = os.path.join(paths, 'wikihow.test.label.jsonl') + _paths['test'] = os.path.join(paths, + 'wikihow.test.label.jsonl') paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -314,8 +342,10 @@ class XsumLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'xsum.train.label.jsonl')): - raise FileNotFoundError(f"xsum.train.label.jsonl is not found in {paths}") + if not os.path.isfile( + os.path.join(paths, 'xsum.train.label.jsonl')): + raise FileNotFoundError( + f"xsum.train.label.jsonl is not found in {paths}") _paths['train'] = os.path.join(paths, 'xsum.train.label.jsonl') _paths['dev'] = os.path.join(paths, 'xsum.valid.label.jsonl') _paths['test'] = os.path.join(paths, 'xsum.test.label.jsonl') @@ -342,7 +372,8 @@ class RedditTIFULoader(SumLoader): super(RedditTIFULoader, self).__init__() self.valid_ratio = valid_ratio self.test_ratio = test_ratio - assert tag in ["long", "short"], "tag not valid (neither long nor short)!" + assert tag in ["long", + "short"], "tag not valid (neither long nor short)!" self.tag = tag def load(self, paths: Optional[Path] = None) -> DataBundle: @@ -352,14 +383,25 @@ class RedditTIFULoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, f"tifu_{self.tag}.all.label.jsonl")): - raise FileNotFoundError(f"tifu_{self.tag}.all.label.jsonl is not found in {paths}") - - _split_set(f"tifu_{self.tag}.all.label", paths, split_name1="middev", split_name2="train", + if not os.path.isfile( + os.path.join(paths, + f"tifu_{self.tag}.all.label.jsonl")): + raise FileNotFoundError( + f"tifu_{self.tag}.all.label.jsonl is not found in {paths}" + ) + + _split_set(f"tifu_{self.tag}.all.label", + paths, + split_name1="middev", + split_name2="train", ratio=self.valid_ratio + self.test_ratio) if self.valid_ratio + self.test_ratio > 0: - _split_set('middev', paths, split_name1="test", split_name2="dev", - ratio=self.test_ratio / (self.valid_ratio + self.test_ratio)) + _split_set('middev', + paths, + split_name1="test", + split_name2="dev", + ratio=self.test_ratio / + (self.valid_ratio + self.test_ratio)) _paths['train'] = os.path.join(paths, 'train.jsonl') if self.valid_ratio > 0: _paths['dev'] = os.path.join(paths, 'dev.jsonl') @@ -403,13 +445,21 @@ class AMILoader(SumLoader): if paths: if os.path.isdir(paths): if not os.path.isfile(os.path.join(paths, 'AMI.jsonl')): - raise FileNotFoundError(f"AMI.jsonl is not found in {paths}") + raise FileNotFoundError( + f"AMI.jsonl is not found in {paths}") - _split_set('AMI', paths, split_name1="middev", split_name2="train", + _split_set('AMI', + paths, + split_name1="middev", + split_name2="train", ratio=self.valid_ratio + self.test_ratio) if self.valid_ratio + self.test_ratio > 0: - _split_set('middev', paths, split_name1="test", split_name2="dev", - ratio=self.test_ratio / (self.valid_ratio + self.test_ratio)) + _split_set('middev', + paths, + split_name1="test", + split_name2="dev", + ratio=self.test_ratio / + (self.valid_ratio + self.test_ratio)) _paths['train'] = os.path.join(paths, 'train.jsonl') if self.valid_ratio > 0: _paths['dev'] = os.path.join(paths, 'dev.jsonl') @@ -452,13 +502,21 @@ class ICSILoader(SumLoader): if paths: if os.path.isdir(paths): if not os.path.isfile(os.path.join(paths, 'ICSI.jsonl')): - raise FileNotFoundError(f"ICSI.jsonl is not found in {paths}") + raise FileNotFoundError( + f"ICSI.jsonl is not found in {paths}") - _split_set('ICSI', paths, split_name1="middev", split_name2="train", + _split_set('ICSI', + paths, + split_name1="middev", + split_name2="train", ratio=self.valid_ratio + self.test_ratio) if self.valid_ratio + self.test_ratio > 0: - _split_set('middev', paths, split_name1="test", split_name2="dev", - ratio=self.test_ratio / (self.valid_ratio + self.test_ratio)) + _split_set('middev', + paths, + split_name1="test", + split_name2="dev", + ratio=self.test_ratio / + (self.valid_ratio + self.test_ratio)) _paths['train'] = os.path.join(paths, 'train.jsonl') if self.valid_ratio > 0: _paths['dev'] = os.path.join(paths, 'dev.jsonl') @@ -473,7 +531,13 @@ class ICSILoader(SumLoader): return data_bundle -def _split_set(dataset_name, data_dir, split_name1="dev", split_name2="train", ratio=0.0, suffix='jsonl', keep_orig: bool = True): +def _split_set(dataset_name, + data_dir, + split_name1="dev", + split_name2="train", + ratio=0.0, + suffix='jsonl', + keep_orig: bool = True): if ratio == 0: os.renames(os.path.join(data_dir, f'{dataset_name}.{suffix}'), os.path.join(data_dir, f'{split_name2}.{suffix}')) @@ -494,11 +558,13 @@ def _split_set(dataset_name, data_dir, split_name1="dev", split_name2="train", r if keep_orig: assert split_name1 != dataset_name and split_name2 != dataset_name else: - os.remove(os.path.join(data_dir, f'{dataset_name}.{suffix}')) + os.remove( + os.path.join(data_dir, f'{dataset_name}.{suffix}')) os.renames(os.path.join(data_dir, f'middle_file.{suffix}'), os.path.join(data_dir, f'{split_name2}.{suffix}')) finally: - if os.path.exists(os.path.join(data_dir, f'middle_file.{suffix}')): + if os.path.exists( + os.path.join(data_dir, f'middle_file.{suffix}')): os.remove(os.path.join(data_dir, f'middle_file.{suffix}')) return data_dir -- Gitee From b738843d4b1292328f5f727449b4dc67ec6b3415 Mon Sep 17 00:00:00 2001 From: Zhangyue Yin <40982202+yinzhangyue@users.noreply.github.com> Date: Thu, 4 Mar 2021 23:06:07 +0800 Subject: [PATCH 3/4] Delete summarizationLoader.cpython-37.pyc --- .../summarizationLoader.cpython-37.pyc | Bin 17107 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 fastSum/Dataloader/__pycache__/summarizationLoader.cpython-37.pyc diff --git a/fastSum/Dataloader/__pycache__/summarizationLoader.cpython-37.pyc b/fastSum/Dataloader/__pycache__/summarizationLoader.cpython-37.pyc deleted file mode 100644 index 558ec94158cfa67c94c03f44931251862ec6f9c4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 17107 zcmeI4?Q;}Ie#hr&cV{0YEdm4x3=77_6)N0QDtCvkP^IeXT&Y|hURB z??-rmtwVU7(U0(Igx9kT2yYPKH3$!~jR zBD__E*CV`*J&N$7BD?|N?d&mx9~0q0gm~Vx27vV?P z0b?^8zAdwTZ2yf&ehX91O4%nrqPET|t~O?t99RF+v}0B)MQc1OyV~`0?POKA_oCqxO2yJ8qcDxz+;o8=7_PPPTjQ3S6sb%k z*JGonPvz?7bUi-Gub8h1Z}O;VS!ZiyFaFs0_}FRT)lVHe>s{P?x@I}%6XV8=Em9KC z*Cxsa6aKz424kj^Id|&Bi(cZ{(bGY0$@Auvd9pg=<@b{96_QQ5YO!p(ks{T>jZL6@ zB)O4NrNYXtW|nNTo+yhlD=5>APK)yDN$!>9n3FX(Y8T6NS8N7lT&m8v+7$##)ULr1 z@jdpJ;W4DV%%r@+WXP@pxlI}oBWlEq*rc*S+5joOEx#pA%1mRC8$EfQrH~%qAQ=f3 zy{)hq)6siLlG7VfKJ}KQN%K z{pH=8Km6#sZ(5=+?|tXSy|;gK@2}n$7gC2uCr!g*c1iK8OTQXU6{o!iP6HL6()j?B znc_&7WmHw2l^w-V8`5>tG&VihM0 zD=YK02xTKYMv>8Tqi3Eye$I{9wP}O%H58{83TDN03I!5U+ps3RJ7hOrD3q+CZ5Imf zNguy(?CA?H+6K2TTs158OBcqfrCQmjIQE5SP3L57;=*LncF;cIFG;c24i~)Yne`%x zzSn@c>8pGrO5cXl7JQO*MUfRbq15}sX3n)#EkaTc3E_rQk>LeZb+BW zG1rw@#nBoHx(1zgQ=L_v$mQsrSVNtY`H1lA*CgvL&t;nak#Qng)8z#(ed zlqHHXxyVKn4dM1PW3q_h)(R4p*paB#oaq{PV)FS7rL*cxg<6DU-cG}jw~Zr{E314Q5bZi^ESnF+8+yWUsiK z_o6nTV3UATThnRx!*L`M>g&oKWlr)LgDzD0Akz3ja;PQIyfmnSRu;`Jnwk<%;0y-T znwEA5NAftvWCj=Ut#HCoOyi4sS5)Ok@c1B1zb)O7<}lh!X%-`lf8CKUD|gg}jInh? z_%-21X!Ok~bFzKXi85UIUFlurx_k|z6xaKy4TY(P@rZtmHm4va(h?IDF|or^Lz$H; zy-tksVY-)-5CNJK^oD##pOeO<>lo8cyrHt>O?g&pXz2XZczvBWuR6SFR&s*DxyzW| zt@^GG2^rJQRH{y9vWhuBV^%V5MBE(Y^jgs}S!+(N%8b|Q1Acaf?X)FgO0b>!V32Du zBks>&U>hZ;%CF|a{-?&KhF^`-rY<$5hV=593wZchg&Mwr25rdSP+rBOiQz@-kpjsr z(52*t_zv8}cf;{zqKVg!6NeerJEdxQ`k4!qKaJ<8RNvkz69OuSt!k-g+0W#fH=r2Z zXQZ7?h{4XYK~ZkRw24xLZdbcpH7o9_6JXtV)j2^ec%1Vpcl!|>9Sfe>i;I|!QMLs| zFdxzb3rTdNV#Q?hXDG)c611`wtI#58c2LQF#L z$6r+GRR-X%r{_N~Q8CGfaXa5f&VF(ZkR#wngUBP`Jmft^35zq1hUQ_MHcb_Jlo*#% z#egj#B0N4r!ihDdPq?q506`ZK*)76}6{H8Z348DG$k2h@(34M%Jn3MZHHwCo^{ zZfNMqefvE;_0aRmaOj|4ntJJ6dfD|Xu38#TwzCqR7C8T%3ib|msRd9x@s{Vx! z(FJ(X*n5-e)&?(%?Gn8D#OpVD`{v{8I?56V2@{qJU_(#khK3IwI3!AMeMk3z96O08 zS(;-%rq=d3wmy7juQ47Tl5G@kMCpa-A|TuI35gvz){9>UTa&Q|KTCmdhh$qBG}4yO6kC?Ag2z%S{G!uVrlhXfi!?t zkF`m|N zdZS@Qyv>SH?LIxD7Sz8BjG{sIG0t{IQEZnO)#K|`_>g)7B?(ga(aS`s5qh0#$~cTDO4s&&B)AXV$fLhqgX`XH6Q z9JVE|n78gn0)e0xSm!-eqZLG|Jyg9N5be&N1lzWK{RJ;VQOe3H#BnafR8EquKj?RdJeZPjj1mMQVC40Za(|N!>9pL#9IVV zI~R}dKqoJk4vVO@UREoXOe~NAkm=DD8L&uUYEj*W<@+3XhU&HrqP!TENK2tbPoZyK z-Mhde8dS#@@Q7j;=MjC5ov(<+5~1d19I9B!P!%f%%}wYQLvypKij{KHp3W+%tomH~ zePk-Q)S}J0TzJA>Z)oV?i1$&jO~>`#cfL)>d!H`_N*eU5`ln6mA??ap&^MF{&?tmU zN}=*p)ZXnxn9zA@K#8jm6?%{=)U8OoKsqb`(7lpc!7;l+f6~lBc89YPR zAf9K$W!bISn5Xu28bu2dx=C?Ifs+=eo1nFV@9!P8p%_ZKot~bJiXyU}!XZCLw?pe( z;%DG2t}|wnOppeLB=2efmU&_0xP$>I~6JxR;hkKw{rU}I|J&e`fMLdhJD6Wd*Ostbw>}EukTEhC2 z#{b&U8L6T{iHxplrf&4F4^`8WLJ=)`SmeOy-hU0#QX(xO(h?%=*I`mScYNa?yzPv?qO+hVc4b$-kOp4Su2OWbJLWKdKS zS|89COWM;Gw2VpPbGS9I&y5I?;^;C-=y`an(xkUP_jKpt$jDct6O!lqyQSrJKMEn= zBk842R+s>?3uJbwX`pwWcH4Zd5HWs6kvp_!y#|D~e{(?nrV1m!OO-!@?1LF1Du$qr znL&})sJ(FMFO5G^_oY8g%p0WTMrZgXW9_1k4rR}AWGL#dd8* ziI1K>HGe8WR5zv&==EOY)!MY3Gb@uhSYZ@Px!TlT7*y;Hw5)x#-Xxy5IfW1b>1dlm zNK2Cx)L~JeKy5YD^d;4F8LSl32~8?6^MN9(Hr_NPKqrP~1yD&6+Q4~aC#Hp_EC$FA zY46(+c*@v9b5STNk%r@+Q>P{1kyI#=D`^G5?+;t{en`ItIW10l{vU#ip0e=Ig};UJ ziz*Ap7?B)iAz-skdmo?)|$k z&7(YMO?im4e+tt)%IhJ?6N2PDEhLBWlpuLBBzX!XPlqI5H4f7s*Z_F|9Spq-pq*Ed z#$TikB?Ud`!F`$h=gE1A9JKKNMlk;)$(Rr3ZB5x)ipV`n5P6hZfM6gB&XP}p*KNh zYCe0fp&+r0r?_Q?FIlbm;O#>-mKVGiP~vX@-h0$#E#Re<_UdN4{y%xyXF^^fsy=!* z3hPRN;f3tgDDnRwd{+Uy0el~Ak*^g<-~aqf=-W5%A%WkQX{eIZ1^f~Q{&^mL6I{T7 z^ytM*eh(gTxDg?J=U`V8pE8VTSR9%cOEEu{7LTN$O!JrhE9BYG@IFD|M;z&n3?o!w zZ`F`!OWcTH`-^RH7^}9%&cSB9A)m#Y3f!ZG*C=h*$hTzsm!bp>o9|yTxuI#;vrkF< zb*4^9_8(M`3*KaiP4+FhzT1h0IknY+Hr&`|tABKBHrBj`>7tIxoPuq(mp(bGH`5z! zwF$2onT=l=YQ)B*mv_x3oY4conW*Bsg$ zi{X~?TGM3IpL#8u7#Ax&SI&I|A;8;A?76)+JoHYeX0gnKLFQ?^Sh;9qb`9+r&gNX1 z*CRD&@`(|$*)l4nDl;n=>nYm$!_JdzP;t!+j32Vd(v5LpkxfY{Tw$GxoaM%IBtwAbuDv( z(>^V~itMMypL;%MR>jA6SZ9-MUVfk@zST617Ml6~AhA^CMy^EvDZsPyUK=DJH|m5! z%`q*zX-X{Yh50VzAq-(XW4jh-)BgkY|1`OCUa;?vKmgwYcy|h+wSUGdTaVI{OwBb6u&#t+tW|u L?{hGn5P$yzppvMd -- Gitee From 1a5d7cbd02b8d493b72c1d9b4b154c3c62a4c24c Mon Sep 17 00:00:00 2001 From: Zhangyue Yin <40982202+yinzhangyue@users.noreply.github.com> Date: Thu, 4 Mar 2021 23:24:28 +0800 Subject: [PATCH 4/4] Format Format --- fastSum/Dataloader/example.py | 4 +- fastSum/Dataloader/summarizationLoader.py | 299 +++++++++++----------- 2 files changed, 155 insertions(+), 148 deletions(-) diff --git a/fastSum/Dataloader/example.py b/fastSum/Dataloader/example.py index 7c0d190..675aa93 100644 --- a/fastSum/Dataloader/example.py +++ b/fastSum/Dataloader/example.py @@ -1,11 +1,11 @@ from fastNLP.io.file_utils import get_cache_path from summarizationLoader import ArxivLoader -if __name__ == '__main__': +if __name__ == "__main__": # 请设置fastNLP默认cache的存放路径FASTNLP_CACHE_DIR, get_cache_path会获取设置下载的数据位置 # 详细可参考: https://gitee.com/fastnlp/fastNLP/blob/7b4e099c5267efb6a4a88b9d789a0940be05bb56/fastNLP/io/file_utils.py#L228 - print(f'下载的数据位置: {get_cache_path()}') + print(f"下载的数据位置: {get_cache_path()}") ArxivLoader().download() data = ArxivLoader().load() print(data) \ No newline at end of file diff --git a/fastSum/Dataloader/summarizationLoader.py b/fastSum/Dataloader/summarizationLoader.py index c813399..fcfc9db 100644 --- a/fastSum/Dataloader/summarizationLoader.py +++ b/fastSum/Dataloader/summarizationLoader.py @@ -10,7 +10,7 @@ from fastNLP.io.file_utils import get_cache_path, _get_dataset_url, cached_path DATASET_DIR = { # Summarization - 'ami': "AMI.zip", + "ami": "AMI.zip", "arxiv": "Arxiv.zip", "billsum": "BillSum.zip", "cnndm": "CNNDM.zip", @@ -20,7 +20,7 @@ DATASET_DIR = { "reddit tifu": "Reddit TIFU.zip", "samsum": "SAMSum.zip", "wikihow": "WikiHow.zip", - "xsum": "Xsum.zip" + "xsum": "Xsum.zip", } @@ -33,11 +33,7 @@ class SumLoader(JsonLoader): def __init__(self, fields: Optional[Dict[str, str]] = None): if fields is None: - fields = { - 'text': 'text', - 'summary': 'summary', - 'label': Const.TARGET - } + fields = {"text": "text", "summary": "summary", "label": Const.TARGET} super(SumLoader, self).__init__(fields=fields) def load(self, paths: Optional[Path] = None) -> DataBundle: @@ -46,9 +42,9 @@ class SumLoader(JsonLoader): def download(self): default_cache_path = get_cache_path() url = _get_dataset_url(self.DATASET_NAME, DATASET_DIR) - output_dir = cached_path(url_or_filename=url, - cache_dir=default_cache_path, - name='dataset') + output_dir = cached_path( + url_or_filename=url, cache_dir=default_cache_path, name="dataset" + ) # https://gitee.com/fastnlp/fastNLP/blob/7b4e099c5267efb6a4a88b9d789a0940be05bb56/fastNLP/io/file_utils.py#L201 # 如果只有一个文件, get_filepath 返回 filepath + filename # os.path.dirname 反向处理 @@ -77,14 +73,13 @@ class CNNDMLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile( - os.path.join(paths, 'CNNDM.train.label.jsonl')): + if not os.path.isfile(os.path.join(paths, "CNNDM.train.label.jsonl")): raise FileNotFoundError( - f"CNNDM.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, - 'CNNDM.train.label.jsonl') - _paths['dev'] = os.path.join(paths, 'CNNDM.valid.label.jsonl') - _paths['test'] = os.path.join(paths, 'CNNDM.test.label.jsonl') + f"CNNDM.train.label.jsonl is not found in {paths}" + ) + _paths["train"] = os.path.join(paths, "CNNDM.train.label.jsonl") + _paths["dev"] = os.path.join(paths, "CNNDM.valid.label.jsonl") + _paths["test"] = os.path.join(paths, "CNNDM.test.label.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -114,14 +109,13 @@ class ArxivLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile( - os.path.join(paths, 'arxiv.train.label.jsonl')): + if not os.path.isfile(os.path.join(paths, "arxiv.train.label.jsonl")): raise FileNotFoundError( - f"arxiv.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, - 'arxiv.train.label.jsonl') - _paths['dev'] = os.path.join(paths, 'arxiv.valid.label.jsonl') - _paths['test'] = os.path.join(paths, 'arxiv.test.label.jsonl') + f"arxiv.train.label.jsonl is not found in {paths}" + ) + _paths["train"] = os.path.join(paths, "arxiv.train.label.jsonl") + _paths["dev"] = os.path.join(paths, "arxiv.valid.label.jsonl") + _paths["test"] = os.path.join(paths, "arxiv.test.label.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -152,16 +146,14 @@ class BillSumLoader(SumLoader): if paths: if os.path.isdir(paths): if not os.path.isfile( - os.path.join(paths, 'billsum_us.train.label.jsonl')): + os.path.join(paths, "billsum_us.train.label.jsonl") + ): raise FileNotFoundError( f"billsum_us.train.label.jsonl is not found in {paths}" ) - _paths['train'] = os.path.join(paths, - 'billsum_us.train.label.jsonl') - _paths['dev'] = os.path.join(paths, - 'billsum_ca.valid.label.jsonl') - _paths['test'] = os.path.join(paths, - 'billsum_us.test.label.jsonl') + _paths["train"] = os.path.join(paths, "billsum_us.train.label.jsonl") + _paths["dev"] = os.path.join(paths, "billsum_ca.valid.label.jsonl") + _paths["test"] = os.path.join(paths, "billsum_us.test.label.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -192,15 +184,14 @@ class MultiNewsLoader(SumLoader): if paths: if os.path.isdir(paths): if not os.path.isfile( - os.path.join(paths, 'multinews.train.label.jsonl')): + os.path.join(paths, "multinews.train.label.jsonl") + ): raise FileNotFoundError( - f"multinews.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, - 'multinews.train.label.jsonl') - _paths['dev'] = os.path.join(paths, - 'multinews.valid.label.jsonl') - _paths['test'] = os.path.join(paths, - 'multinews.test.label.jsonl') + f"multinews.train.label.jsonl is not found in {paths}" + ) + _paths["train"] = os.path.join(paths, "multinews.train.label.jsonl") + _paths["dev"] = os.path.join(paths, "multinews.valid.label.jsonl") + _paths["test"] = os.path.join(paths, "multinews.test.label.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -230,14 +221,13 @@ class PubmedLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile( - os.path.join(paths, 'pubmed.train.label.jsonl')): + if not os.path.isfile(os.path.join(paths, "pubmed.train.label.jsonl")): raise FileNotFoundError( - f"pubmed.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, - 'pubmed.train.label.jsonl') - _paths['dev'] = os.path.join(paths, 'pubmed.valid.label.jsonl') - _paths['test'] = os.path.join(paths, 'pubmed.test.label.jsonl') + f"pubmed.train.label.jsonl is not found in {paths}" + ) + _paths["train"] = os.path.join(paths, "pubmed.train.label.jsonl") + _paths["dev"] = os.path.join(paths, "pubmed.valid.label.jsonl") + _paths["test"] = os.path.join(paths, "pubmed.test.label.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -267,14 +257,13 @@ class SAMSumLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile( - os.path.join(paths, 'SAMSum.train.label.jsonl')): + if not os.path.isfile(os.path.join(paths, "SAMSum.train.label.jsonl")): raise FileNotFoundError( - f"SAMSum.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, - 'SAMSum.train.label.jsonl') - _paths['dev'] = os.path.join(paths, 'SAMSum.valid.label.jsonl') - _paths['test'] = os.path.join(paths, 'SAMSum.test.label.jsonl') + f"SAMSum.train.label.jsonl is not found in {paths}" + ) + _paths["train"] = os.path.join(paths, "SAMSum.train.label.jsonl") + _paths["dev"] = os.path.join(paths, "SAMSum.valid.label.jsonl") + _paths["test"] = os.path.join(paths, "SAMSum.test.label.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -304,15 +293,13 @@ class WikiHowLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile( - os.path.join(paths, 'wikihow.train.label.jsonl')): + if not os.path.isfile(os.path.join(paths, "wikihow.train.label.jsonl")): raise FileNotFoundError( - f"wikihow.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, - 'wikihow.train.label.jsonl') - _paths['dev'] = os.path.join(paths, 'wikihow.val.label.jsonl') - _paths['test'] = os.path.join(paths, - 'wikihow.test.label.jsonl') + f"wikihow.train.label.jsonl is not found in {paths}" + ) + _paths["train"] = os.path.join(paths, "wikihow.train.label.jsonl") + _paths["dev"] = os.path.join(paths, "wikihow.val.label.jsonl") + _paths["test"] = os.path.join(paths, "wikihow.test.label.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -342,13 +329,13 @@ class XsumLoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile( - os.path.join(paths, 'xsum.train.label.jsonl')): + if not os.path.isfile(os.path.join(paths, "xsum.train.label.jsonl")): raise FileNotFoundError( - f"xsum.train.label.jsonl is not found in {paths}") - _paths['train'] = os.path.join(paths, 'xsum.train.label.jsonl') - _paths['dev'] = os.path.join(paths, 'xsum.valid.label.jsonl') - _paths['test'] = os.path.join(paths, 'xsum.test.label.jsonl') + f"xsum.train.label.jsonl is not found in {paths}" + ) + _paths["train"] = os.path.join(paths, "xsum.train.label.jsonl") + _paths["dev"] = os.path.join(paths, "xsum.valid.label.jsonl") + _paths["test"] = os.path.join(paths, "xsum.test.label.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -372,8 +359,7 @@ class RedditTIFULoader(SumLoader): super(RedditTIFULoader, self).__init__() self.valid_ratio = valid_ratio self.test_ratio = test_ratio - assert tag in ["long", - "short"], "tag not valid (neither long nor short)!" + assert tag in ["long", "short"], "tag not valid (neither long nor short)!" self.tag = tag def load(self, paths: Optional[Path] = None) -> DataBundle: @@ -384,29 +370,32 @@ class RedditTIFULoader(SumLoader): if paths: if os.path.isdir(paths): if not os.path.isfile( - os.path.join(paths, - f"tifu_{self.tag}.all.label.jsonl")): + os.path.join(paths, f"tifu_{self.tag}.all.label.jsonl") + ): raise FileNotFoundError( f"tifu_{self.tag}.all.label.jsonl is not found in {paths}" ) - _split_set(f"tifu_{self.tag}.all.label", - paths, - split_name1="middev", - split_name2="train", - ratio=self.valid_ratio + self.test_ratio) + _split_set( + f"tifu_{self.tag}.all.label", + paths, + split_name1="middev", + split_name2="train", + ratio=self.valid_ratio + self.test_ratio, + ) if self.valid_ratio + self.test_ratio > 0: - _split_set('middev', - paths, - split_name1="test", - split_name2="dev", - ratio=self.test_ratio / - (self.valid_ratio + self.test_ratio)) - _paths['train'] = os.path.join(paths, 'train.jsonl') + _split_set( + "middev", + paths, + split_name1="test", + split_name2="dev", + ratio=self.test_ratio / (self.valid_ratio + self.test_ratio), + ) + _paths["train"] = os.path.join(paths, "train.jsonl") if self.valid_ratio > 0: - _paths['dev'] = os.path.join(paths, 'dev.jsonl') + _paths["dev"] = os.path.join(paths, "dev.jsonl") if self.test_ratio > 0: - _paths['test'] = os.path.join(paths, 'test.jsonl') + _paths["test"] = os.path.join(paths, "test.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -429,8 +418,8 @@ class AMILoader(SumLoader): def __init__(self, valid_ratio=0.05, test_ratio=0.05): # AMI 没有 label fields = { - 'text': 'text', - 'summary': 'summary', + "text": "text", + "summary": "summary", } super(AMILoader, self).__init__(fields) @@ -444,27 +433,29 @@ class AMILoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'AMI.jsonl')): - raise FileNotFoundError( - f"AMI.jsonl is not found in {paths}") - - _split_set('AMI', - paths, - split_name1="middev", - split_name2="train", - ratio=self.valid_ratio + self.test_ratio) + if not os.path.isfile(os.path.join(paths, "AMI.jsonl")): + raise FileNotFoundError(f"AMI.jsonl is not found in {paths}") + + _split_set( + "AMI", + paths, + split_name1="middev", + split_name2="train", + ratio=self.valid_ratio + self.test_ratio, + ) if self.valid_ratio + self.test_ratio > 0: - _split_set('middev', - paths, - split_name1="test", - split_name2="dev", - ratio=self.test_ratio / - (self.valid_ratio + self.test_ratio)) - _paths['train'] = os.path.join(paths, 'train.jsonl') + _split_set( + "middev", + paths, + split_name1="test", + split_name2="dev", + ratio=self.test_ratio / (self.valid_ratio + self.test_ratio), + ) + _paths["train"] = os.path.join(paths, "train.jsonl") if self.valid_ratio > 0: - _paths['dev'] = os.path.join(paths, 'dev.jsonl') + _paths["dev"] = os.path.join(paths, "dev.jsonl") if self.test_ratio > 0: - _paths['test'] = os.path.join(paths, 'test.jsonl') + _paths["test"] = os.path.join(paths, "test.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -487,8 +478,8 @@ class ICSILoader(SumLoader): def __init__(self, valid_ratio=0.05, test_ratio=0.05): # ICSI 没有 label fields = { - 'text': 'text', - 'summary': 'summary', + "text": "text", + "summary": "summary", } super(ICSILoader, self).__init__(fields) self.valid_ratio = valid_ratio @@ -501,27 +492,29 @@ class ICSILoader(SumLoader): _paths = {} if paths: if os.path.isdir(paths): - if not os.path.isfile(os.path.join(paths, 'ICSI.jsonl')): - raise FileNotFoundError( - f"ICSI.jsonl is not found in {paths}") - - _split_set('ICSI', - paths, - split_name1="middev", - split_name2="train", - ratio=self.valid_ratio + self.test_ratio) + if not os.path.isfile(os.path.join(paths, "ICSI.jsonl")): + raise FileNotFoundError(f"ICSI.jsonl is not found in {paths}") + + _split_set( + "ICSI", + paths, + split_name1="middev", + split_name2="train", + ratio=self.valid_ratio + self.test_ratio, + ) if self.valid_ratio + self.test_ratio > 0: - _split_set('middev', - paths, - split_name1="test", - split_name2="dev", - ratio=self.test_ratio / - (self.valid_ratio + self.test_ratio)) - _paths['train'] = os.path.join(paths, 'train.jsonl') + _split_set( + "middev", + paths, + split_name1="test", + split_name2="dev", + ratio=self.test_ratio / (self.valid_ratio + self.test_ratio), + ) + _paths["train"] = os.path.join(paths, "train.jsonl") if self.valid_ratio > 0: - _paths['dev'] = os.path.join(paths, 'dev.jsonl') + _paths["dev"] = os.path.join(paths, "dev.jsonl") if self.test_ratio > 0: - _paths['test'] = os.path.join(paths, 'test.jsonl') + _paths["test"] = os.path.join(paths, "test.jsonl") paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") @@ -531,25 +524,39 @@ class ICSILoader(SumLoader): return data_bundle -def _split_set(dataset_name, - data_dir, - split_name1="dev", - split_name2="train", - ratio=0.0, - suffix='jsonl', - keep_orig: bool = True): +def _split_set( + dataset_name, + data_dir, + split_name1="dev", + split_name2="train", + ratio=0.0, + suffix="jsonl", + keep_orig: bool = True, +): if ratio == 0: - os.renames(os.path.join(data_dir, f'{dataset_name}.{suffix}'), - os.path.join(data_dir, f'{split_name2}.{suffix}')) + os.renames( + os.path.join(data_dir, f"{dataset_name}.{suffix}"), + os.path.join(data_dir, f"{split_name2}.{suffix}"), + ) return data_dir - if not os.path.exists(os.path.join(data_dir, f'{split_name1}.{suffix}')): + if not os.path.exists(os.path.join(data_dir, f"{split_name1}.{suffix}")): if ratio > 0: assert 0 < ratio < 1, "dev_ratio should be in range (0,1)." try: - with open(os.path.join(data_dir, f'{dataset_name}.{suffix}'), 'r', encoding='utf-8') as f, \ - open(os.path.join(data_dir, f'middle_file.{suffix}'), 'w', encoding='utf-8') as f1, \ - open(os.path.join(data_dir, f'{split_name1}.{suffix}'), 'w', encoding='utf-8') as f2: + with open( + os.path.join(data_dir, f"{dataset_name}.{suffix}"), + "r", + encoding="utf-8", + ) as f, open( + os.path.join(data_dir, f"middle_file.{suffix}"), + "w", + encoding="utf-8", + ) as f1, open( + os.path.join(data_dir, f"{split_name1}.{suffix}"), + "w", + encoding="utf-8", + ) as f2: for line in f: if random.random() < ratio: f2.write(line) @@ -558,13 +565,13 @@ def _split_set(dataset_name, if keep_orig: assert split_name1 != dataset_name and split_name2 != dataset_name else: - os.remove( - os.path.join(data_dir, f'{dataset_name}.{suffix}')) - os.renames(os.path.join(data_dir, f'middle_file.{suffix}'), - os.path.join(data_dir, f'{split_name2}.{suffix}')) + os.remove(os.path.join(data_dir, f"{dataset_name}.{suffix}")) + os.renames( + os.path.join(data_dir, f"middle_file.{suffix}"), + os.path.join(data_dir, f"{split_name2}.{suffix}"), + ) finally: - if os.path.exists( - os.path.join(data_dir, f'middle_file.{suffix}')): - os.remove(os.path.join(data_dir, f'middle_file.{suffix}')) + if os.path.exists(os.path.join(data_dir, f"middle_file.{suffix}")): + os.remove(os.path.join(data_dir, f"middle_file.{suffix}")) - return data_dir + return data_dir \ No newline at end of file -- Gitee