diff --git a/fastSum/BertSum/README.md b/fastSum/BertSum/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d815587393b1927d534e10ed202ca20f23ac7506
--- /dev/null
+++ b/fastSum/BertSum/README.md
@@ -0,0 +1,2 @@
+1. Reference command: `CUDA_VISIBLE_DEVICES=0,1 python train_BertSum.py --mode train --save_path save --label_type greedy --batch_size 8`
+2. With the command above as an example, `data/greedy` holds the data preprocessed by `pre_process.py`, and `data/uncased_L-12_H-768_A-12` holds the pretrained BERT model
\ No newline at end of file
diff --git a/fastSum/BertSum/dataloader.py b/fastSum/BertSum/dataloader.py
index 6af797e4589cbb3c45d076f19634e3b199fc2c4d..22b3cb0483c2afac7ef69f133e4f3b3fa41f40b1 100644
--- a/fastSum/BertSum/dataloader.py
+++ b/fastSum/BertSum/dataloader.py
@@ -1,114 +1,182 @@
 from time import time
 from datetime import timedelta
+from typing import Dict, List
 
-from fastNLP.io.dataset_loader import JsonLoader
-from fastNLP.modules.encoder._bert import BertTokenizer
+from fastNLP.io import JsonLoader
+from fastNLP.modules.tokenizer import BertTokenizer
 from fastNLP.io.data_bundle import DataBundle
 from fastNLP.core.const import Const
+from fastNLP.core.instance import Instance
+
 
 class BertData(JsonLoader):
 
-    def __init__(self, max_nsents=60, max_ntokens=100, max_len=512):
+    def __init__(self, fields: Dict[str, str], max_nsents: int = 60, max_ntokens: int = 100, max_len: int = 512):
+        """
+
+        :param fields:
+        :param max_nsents: maximum number of sentences kept per article
+        :param max_ntokens: maximum number of tokens kept per sentence
+        :param max_len: maximum total number of tokens per article
+        """
 
-        fields = {'article': 'article',
-                  'label': 'label'}
-        super(BertData, self).__init__(fields=fields)
+        # fields = {
+        #     'text': 'text',
+        #     'label': 'label',
+        #     'summary': 'summary'  # the train split has no summary field
+        # }
+        super(BertData, self).__init__(fields=fields, dropna=True)
 
         self.max_nsents = max_nsents
         self.max_ntokens = max_ntokens
         self.max_len = max_len
 
-        self.tokenizer = BertTokenizer.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
+        self.tokenizer = BertTokenizer.from_pretrained('data/uncased_L-12_H-768_A-12')
         self.cls_id = self.tokenizer.vocab['[CLS]']
         self.sep_id = self.tokenizer.vocab['[SEP]']
         self.pad_id = self.tokenizer.vocab['[PAD]']
+        assert self.pad_id == 0
 
-    def _load(self, paths):
-        dataset = super(BertData, self)._load(paths)
-        return dataset
+    # _load simply reuses the JsonLoader implementation without any changes
+    # def _load(self, path):
+    #     dataset = super(BertData, self)._load(path)
+    #     return dataset
 
-    def process(self, paths):
+    def process(self, paths: Dict[str, str]) -> DataBundle:
+        """
+
+        :param paths: Dict[name, real_path]; real_path is the actual file path, name is an index or short name for it
+        :return:
+        """
 
-        def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens):
+        def truncate_articles(instance: Instance, max_nsents: int = self.max_nsents, max_ntokens: int = self.max_ntokens) -> List[str]:
+            """
+
+            :param instance: one data record
+            :param max_nsents: see max_nsents in __init__
+            :param max_ntokens: see max_ntokens in __init__
+            :return: the truncated article as List[sentence], where each sentence is a str
+            """
            article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']]
            return article[:max_nsents]
 
-        def truncate_labels(instance):
-            label = list(filter(lambda x: x < len(instance['article']), instance['label']))
+        def truncate_labels(instance: Instance):
+            """
+            Drop every indicator label that points beyond the truncated article (max_nsents).
+            :param instance:
+            :return:
+            """
+            label = list(filter(lambda x: x < len(instance['article']), instance['label']))  # label is indices_chose: it tells which sentences were selected as the summary
            return label
 
-        def bert_tokenize(instance, tokenizer, max_len, pad_value):
+        def bert_tokenize(instance: Instance, tokenizer: BertTokenizer, max_len: int, pad_value: int) -> List[int]:
+            """
+            Run WordPiece tokenization.
+            :param instance:
+            :param tokenizer:
+            :param max_len: see max_len in __init__
+            :param pad_value: id of [PAD]
+            :return:
+            """
            article = instance['article']
            article = ' [SEP] [CLS] '.join(article)
+            # the tokenizer splits the input and runs WordPiece, just like the BertTokenizer in pytorch-pretrained-bert, turning an article into a List[str]
+            # truncate with max_len
            word_pieces = tokenizer.tokenize(article)[:(max_len - 2)]
            word_pieces = ['[CLS]'] + word_pieces + ['[SEP]']
            token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
+            # pad up to full length
            while len(token_ids) < max_len:
                token_ids.append(pad_value)
-            assert len(token_ids) == max_len
+            # assert len(token_ids) == max_len
            return token_ids
 
-        def get_seg_id(instance, max_len, sep_id):
-            _segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id]
-            segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
+        def get_seg_id(instance: Instance, max_len: int, sep_id: int, pad_value: int) -> List[int]:
+            """
+            Build the interval segment embedding ids.
+
+            :param instance:
+            :param max_len: see max_len in __init__
+            :param sep_id: id of [SEP]
+            :param pad_value:
+            :return:
+            """
+            _segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id]  # [CLS, sentence 0, SEP, CLS, sentence 1, SEP, ..., sentence n-1, SEP]
+            segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]  # pairwise differences give the interval lengths; List[interval length]
            segment_id = []
+            # [(CLS, sentence 0, SEP), (CLS, sentence 1, SEP), ..., sentence n-1, SEP]
+            # [(0, 0.......0, 0), (1, 1....1, 1), ..., ...]
            for i, length in enumerate(segs):
                if i % 2 == 0:
                    segment_id += length * [0]
                else:
                    segment_id += length * [1]
+            # pad up to full length
            while len(segment_id) < max_len:
-                segment_id.append(0)
+                segment_id.append(pad_value)
            return segment_id
 
-        def get_cls_id(instance, cls_id):
+        def get_cls_id(instance: Instance, cls_id: int) -> List[int]:
+            """
+            Find the position of every [CLS] token in the WordPiece-tokenized article.
+            :param instance:
+            :param cls_id: id of [CLS]
+            :return:
+            """
            classification_id = [i for i, idx in enumerate(instance['article']) if idx == cls_id]
            return classification_id
 
-        def get_labels(instance):
+        def get_labels(instance: Instance) -> List[int]:
+            """
+            Turn label (i.e. indices_chose) into a 0/1 tag for every sentence.
+            :param instance:
+            :return:
+            """
            labels = [0] * len(instance['cls_id'])
-            label_idx = list(filter(lambda x: x < len(instance['cls_id']), instance['label']))
+            label_idx = list(filter(lambda x: x < len(instance['cls_id']), instance['label']))  # this filter looks redundant with truncate_labels
            for idx in label_idx:
                labels[idx] = 1
            return labels
 
        datasets = {}
        for name in paths:
+            # _load calls _read_json internally
+            # _read_json yields "line_idx, _res" (see io/file_reader.py); line_idx is the index of the loaded line, and _res is a Dict[fields' key, value for that key in this record]
+            # after loading, _load first maps every fields' key to its fields' value, giving a Dict[fields' value, value for that key in this record], then wraps the record in an Instance (core/instance.py),
+            # and every Instance is appended into a single DataSet container (core/dataset.py), which is finally returned
            datasets[name] = self._load(paths[name])
+
+            datasets[name].copy_field('text', 'article')
 
-            # remove empty samples
+            # remove empty samples (drop empty records)
            datasets[name].drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0)
 
-            # truncate articles
+            # truncate articles (cap article length)
+            # new_field_name: the value returned by func is stored in the field `new_field_name`; if that name clashes with an existing field, the old field is overwritten.
            datasets[name].apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens), new_field_name='article')
 
-            # truncate labels
+            # truncate labels (same as above, applied to the labels)
+            # see new_field_name above
            datasets[name].apply(truncate_labels, new_field_name='label')
 
-            # tokenize and convert tokens to id
-            datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id), new_field_name='article')
+            # tokenize and convert tokens to id (run WordPiece)
+            datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, pad_value=self.pad_id), new_field_name='article')
 
-            # get segment id
-            datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id), new_field_name='segment_id')
+            # get segment id (build the interval segment embeddings)
+            datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id, pad_value=0), new_field_name='segment_id')
 
-            # get classification id
+            # get classification id (positions of [CLS])
            datasets[name].apply(lambda ins: get_cls_id(ins, self.cls_id), new_field_name='cls_id')
 
-            # get label
+            # get label (0/1 sentence tags)
            datasets[name].apply(get_labels, new_field_name='label')
-
-            # rename filed
-            datasets[name].rename_field('article', Const.INPUTS(0))
-            datasets[name].rename_field('segment_id', Const.INPUTS(1))
-            datasets[name].rename_field('cls_id', Const.INPUTS(2))
-            datasets[name].rename_field('lbael', Const.TARGET)
 
            # set input and target
-            datasets[name].set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2))
-            datasets[name].set_target(Const.TARGET)
+            datasets[name].set_input('article', 'segment_id', 'cls_id')
+            datasets[name].set_target('label')
 
-            # set paddding value
-            datasets[name].set_pad_val('article', 0)
+            # set padding value
+            datasets[name].set_pad_val('article', self.pad_id)
 
        return DataBundle(datasets=datasets)
 
@@ -117,19 +185,25 @@ class BertSumLoader(JsonLoader):
 
    def __init__(self):
        fields = {'article': 'article',
-                    'segment_id': 'segment_id',
-                    'cls_id': 'cls_id',
-                    'label': Const.TARGET
-                }
+                  'segment_id': 'segment_id',
+                  'cls_id': 'cls_id',
+                  'label': Const.TARGET
+                  }
        super(BertSumLoader, self).__init__(fields=fields)
 
-    def _load(self, paths):
-        dataset = super(BertSumLoader, self)._load(paths)
-        return dataset
+    # _load simply reuses the JsonLoader implementation without any changes
+    # def _load(self, path):
+    #     dataset = super(BertSumLoader, self)._load(path)
+    #     return dataset
 
    def process(self, paths):
+        """
+
+        :param paths: Dict[name, real_path]; real_path is the actual file path, name is an index or short name for it
+        :return:
+        """
 
-        def get_seq_len(instance):
+        def get_seq_len(instance: Instance):
            return len(instance['article'])
 
        print('Start loading datasets !!!')
@@ -138,6 +212,10 @@ class BertSumLoader(JsonLoader):
        # load datasets
        datasets = {}
        for name in paths:
+            # _load calls _read_json internally
+            # _read_json yields "line_idx, _res" (see io/file_reader.py); line_idx is the index of the loaded line, and _res is a Dict[fields' key, value for that key in this record]
+            # after loading, _load first maps every fields' key to its fields' value, giving a Dict[fields' value, value for that key in this record], then wraps the record in an Instance (core/instance.py),
+            # and every Instance is appended into a single DataSet container (core/dataset.py), which is finally returned
            datasets[name] = self._load(paths[name])
 
            datasets[name].apply(get_seq_len, new_field_name='seq_len')
@@ -147,8 +225,8 @@ class BertSumLoader(JsonLoader):
            datasets[name].set_target(Const.TARGET)
 
            # set padding value
-            datasets[name].set_pad_val('article', 0)
-            datasets[name].set_pad_val('segment_id', 0)
+            datasets[name].set_pad_val('article', 0)     # consistent with the padding in BertData
+            datasets[name].set_pad_val('segment_id', 0)  # consistent with the padding in BertData
            datasets[name].set_pad_val('cls_id', -1)
            datasets[name].set_pad_val(Const.TARGET, 0)
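For reference, the field pipeline that `BertData.process` builds above can be illustrated on a toy article without fastNLP or a real BERT vocabulary. The sketch below is only an approximation: `ToyTokenizer`, its tiny vocabulary and the example sentences are made up for illustration, but the four outputs mirror the `article`, `segment_id`, `cls_id` and `label` fields produced by the apply calls above.

```python
# Minimal sketch of the per-article transformation in BertData.process.
# ToyTokenizer is a hypothetical stand-in exposing the three members used above
# (tokenize, convert_tokens_to_ids, vocab); the real loader uses BertTokenizer.
class ToyTokenizer:
    def __init__(self):
        self.vocab = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2}

    def tokenize(self, text):
        return text.split()          # the real tokenizer runs WordPiece here

    def convert_tokens_to_ids(self, tokens):
        return [self.vocab.setdefault(t, len(self.vocab)) for t in tokens]


def encode_article(sentences, label, tokenizer, max_len=20):
    # join sentences with " [SEP] [CLS] ", wrap with [CLS] ... [SEP], truncate, pad
    pieces = tokenizer.tokenize(' [SEP] [CLS] '.join(s.lower() for s in sentences))
    pieces = ['[CLS]'] + pieces[:max_len - 2] + ['[SEP]']
    article = tokenizer.convert_tokens_to_ids(pieces)
    article += [tokenizer.vocab['[PAD]']] * (max_len - len(article))

    # segment ids alternate 0/1 between consecutive [SEP] boundaries, then are 0-padded
    seps = [-1] + [i for i, t in enumerate(article) if t == tokenizer.vocab['[SEP]']]
    segment_id = []
    for k in range(1, len(seps)):
        segment_id += (seps[k] - seps[k - 1]) * [(k - 1) % 2]
    segment_id += [0] * (max_len - len(segment_id))

    # one [CLS] position per sentence; label lists the sentences chosen as the summary
    cls_id = [i for i, t in enumerate(article) if t == tokenizer.vocab['[CLS]']]
    labels = [1 if i in label else 0 for i in range(len(cls_id))]
    return article, segment_id, cls_id, labels


if __name__ == '__main__':
    out = encode_article(['The cat sat .', 'Dogs bark .', 'Birds fly .'], label=[0, 2],
                         tokenizer=ToyTokenizer())
    for name, value in zip(('article', 'segment_id', 'cls_id', 'label'), out):
        print(name, value)
```

Running this prints one 0/1 tag per sentence aligned with the `cls_id` positions, which is exactly the per-sentence target that the loss and the classifier below consume.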
diff --git a/fastSum/BertSum/metrics.py b/fastSum/BertSum/metrics.py
index 228f6789a9864cf9575cf93f78343596db48c6f7..e92e5a2edc3e649c98b8fa0754be193a698a1f5a 100644
--- a/fastSum/BertSum/metrics.py
+++ b/fastSum/BertSum/metrics.py
@@ -14,8 +14,10 @@ from pyrouge.utils import log
 from fastNLP.core.losses import LossBase
 from fastNLP.core.metrics import MetricBase
 
+# TODO
 _ROUGE_PATH = '/path/to/RELEASE-1.5.5'
 
+
 class MyBCELoss(LossBase):
 
    def __init__(self, pred=None, target=None, mask=None):
@@ -160,10 +162,10 @@ class RougeMetric(MetricBase):
            for sent in data[i]['abstract']:
                ref.append(sent)
 
-            with open(join(self.dec_path, '{}.dec'.format(i)), 'w') as f:
+            with open(join(self.dec_path, '{}.dec'.format(i)), 'w', encoding='utf-8') as f:
                for sent in dec:
                    print(sent, file=f)
-            with open(join(self.ref_path, '{}.ref'.format(i)), 'w') as f:
+            with open(join(self.ref_path, '{}.ref'.format(i)), 'w', encoding='utf-8') as f:
                for sent in ref:
                    print(sent, file=f)
diff --git a/fastSum/BertSum/model.py b/fastSum/BertSum/model.py
index 34a05495e7b099173b3e7f3985408c58e26b3dc2..18169b562c767a9afbaeb42fb9ea987971ea1453 100644
--- a/fastSum/BertSum/model.py
+++ b/fastSum/BertSum/model.py
@@ -1,51 +1,66 @@
 import torch
 from torch import nn
-from torch.nn import init
-from fastNLP.modules.encoder.bert import BertModel
+from fastNLP.modules.encoder import BertModel
 
 
 class Classifier(nn.Module):
 
-    def __init__(self, hidden_size):
+    def __init__(self, hidden_size: int):
        super(Classifier, self).__init__()
-        self.linear = nn.Linear(hidden_size, 1)
+        self.linear = nn.Linear(hidden_size, 1)  # extractive summarization here boils down to a binary classification of every sentence
        self.sigmoid = nn.Sigmoid()
 
-    def forward(self, inputs, mask_cls):
-        h = self.linear(inputs).squeeze(-1)  # [batch_size, seq_len]
+    def forward(self, inputs: torch.Tensor, mask_cls: torch.Tensor) -> torch.Tensor:
+        """
+
+        :param inputs: shape (N, doc_len, hidden_size); doc_len is the number of sentences in a doc; inputs[i] holds the sentence vectors of doc i
+        :param mask_cls: shape (N, doc_len); indicates how many real sentences each doc has (because of padding); 0 for padded slots, 1 for real sentences
+        :return: shape (N, doc_len); the probability of extracting each sentence
+        """
+        # (N, doc_len, hidden_size = dimension of the sentence vectors) --linear--> (N, doc_len, 1) --squeeze--> (N, doc_len)
+        h = self.linear(inputs).squeeze(-1)
        sent_scores = self.sigmoid(h) * mask_cls.float()
        return sent_scores
 
 
 class BertSum(nn.Module):
 
-    def __init__(self, hidden_size=768):
+    def __init__(self, hidden_size: int = 768):
        super(BertSum, self).__init__()
 
        self.hidden_size = hidden_size
 
-        self.encoder = BertModel.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
+        self.encoder = BertModel.from_pretrained('data/uncased_L-12_H-768_A-12')
        self.decoder = Classifier(self.hidden_size)
 
-    def forward(self, article, segment_id, cls_id):
+    def forward(self, article: torch.Tensor, segment_id: torch.Tensor, cls_id: torch.Tensor):
+        """
+
+        :param article: shape (N, seq_len); seq_len is the length of the longest document (in tokens); padded with 0
+        :param segment_id:
+        :param cls_id: positions of every [CLS]; shape (N, doc_len); doc_len is the number of sentences in a doc; padded with -1 (see BertSumLoader)
+        :return:
+        """
 
        # print(article.device)
        # print(segment_id.device)
        # print(cls_id.device)
 
-        input_mask = 1 - (article == 0).long()
-        mask_cls = 1 - (cls_id == -1).long()
-        assert input_mask.size() == article.size()
-        assert mask_cls.size() == cls_id.size()
+        input_mask = 1 - torch.eq(article, 0).long()  # 1 marks valid positions
+        mask_cls = 1 - torch.eq(cls_id, -1).long()    # 1 marks valid positions
+        # assert input_mask.size() == article.size()
+        # assert mask_cls.size() == cls_id.size()
 
        bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask)
-        bert_out = bert_out[0][-1]  # last layer
+        bert_out = bert_out[0][-1]  # last layer; shape (N, sequence_length, 768 = hidden_size)
 
+        # torch.arange(bert_out.size(0)).unsqueeze(1) has shape (N, 1); cls_id is (N, doc_len); with broadcasting this gathers the [CLS] vectors (i.e. bert_out[row[i][j], cls_id[i][j], :])
+        # sent_emb has shape (N, doc_len, 768 = hidden_size), i.e. one sentence vector per sentence
        sent_emb = bert_out[torch.arange(bert_out.size(0)).unsqueeze(1), cls_id]
        sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()
-        assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size)  # [batch_size, seq_len, hidden_size]
+        # assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size)  # [batch_size, seq_len, hidden_size]
 
-        sent_scores = self.decoder(sent_emb, mask_cls)  # [batch_size, seq_len]
-        assert sent_scores.size() == (article.size(0), cls_id.size(1))
+        sent_scores = self.decoder(sent_emb, mask_cls)  # shape (N, doc_len); the probability of extracting each sentence
+        # assert sent_scores.size() == (article.size(0), cls_id.size(1))
 
        return {'pred': sent_scores, 'mask': mask_cls}
diff --git a/fastSum/BertSum/pre_process.py b/fastSum/BertSum/pre_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cb0706195dc398e6a5cfd7dff68a2e15b454d5a
--- /dev/null
+++ b/fastSum/BertSum/pre_process.py
@@ -0,0 +1,95 @@
+"""
+Tokenization takes a long time, so preprocess the data and save it in advance.
+"""
+
+import argparse
+from os.path import exists
+import json
+import os
+
+from fastNLP.core import DataSet
+
+from dataloader import BertData
+
+
+def save(path: str, dataset: DataSet, encoding='utf-8'):
+    with open(path, 'w', encoding=encoding) as f:
+        for instance in dataset:
+            dict_instance = {key: value for key, value in instance.items()}
+            f.write(json.dumps(dict_instance))
+            f.write('\n')
+
+
+def preprocess_model0(args: argparse.Namespace):
+    # check if the data_path exists
+    paths = {
+        'train': 'data/' + args.label_type + '/train.label.jsonl',
+    }
+    for name in paths:
+        assert exists(paths[name])
+    if not exists(args.save_path):
+        os.makedirs(args.save_path)
+
+    fields = {
+        'text': 'text',
+        'label': 'label',
+        # 'summary': 'summary'  # the train split has no summary field
+    }
+
+    # load summarization datasets
+    datasets = BertData(fields).process(paths)
+    print('Information of dataset is:')
+    print(datasets)
+
+    train_set = datasets.datasets['train']
+
+    save(f'{args.save_path}/bert.train.jsonl', train_set)
+
+
+def preprocess_model1(args: argparse.Namespace):
+    # check if the data_path exists
+    paths = {
+        'val': 'data/' + args.label_type + '/val.label.jsonl',
+        'test': 'data/' + args.label_type + '/test.label.jsonl'
+    }
+    for name in paths:
+        assert exists(paths[name])
+    if not exists(args.save_path):
+        os.makedirs(args.save_path)
+
+    fields = {
+        'text': 'text',
+        'label': 'label',
+        'summary': 'summary'  # unlike the train split, val/test do carry a summary field
+    }
+
+    # load summarization datasets
+    datasets = BertData(fields).process(paths)
+    print('Information of dataset is:')
+    print(datasets)
+
+    valid_set = datasets.datasets['val']
+    test_set = datasets.datasets['test']
+
+    save(f'{args.save_path}/bert.val.jsonl', valid_set)
+    save(f'{args.save_path}/bert.test.jsonl', test_set)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='training/testing of BertSum (Liu et al. 2019)'
+    )
+    parser.add_argument('--mode', required=True,
+                        help='preprocessing, training or testing of BertSum', type=str)
+
+    parser.add_argument('--label_type', default='greedy',
+                        help='greedy/limit', type=str)
+
+    parser.add_argument('--save_path', required=True,
+                        help='root of the model', type=str)
+
+    args = parser.parse_args()
+    assert args.mode == 'preprocess'
+
+    preprocess_model0(args)
+    preprocess_model1(args)
diff --git a/fastSum/BertSum/train_BertSum.py b/fastSum/BertSum/train_BertSum.py
index d34fa0b9ab2f3e3fd3b97e3f2b8afc2bfa5b07b3..23870424deb2853b7d85253bfe620424254c9664 100644
--- a/fastSum/BertSum/train_BertSum.py
+++ b/fastSum/BertSum/train_BertSum.py
@@ -1,39 +1,35 @@
-import sys
 import argparse
 import os
 import json
-import torch
-from time import time
-from datetime import timedelta
 from os.path import join, exists
+
+import torch
 from torch.optim import Adam
+from fastNLP.core.trainer import Trainer
+from fastNLP.core.tester import Tester
 
 from utils import get_data_path, get_rouge_path
-
 from dataloader import BertSumLoader
 from model import BertSum
-from fastNLP.core.optimizer import AdamW
 from metrics import MyBCELoss, LossMetric, RougeMetric
-from fastNLP.core.sampler import BucketSampler
 from callback import MyCallback, SaveModelCallback
-from fastNLP.core.trainer import Trainer
-from fastNLP.core.tester import Tester
 
 
-def configure_training(args):
-    devices = [int(gpu) for gpu in args.gpus.split(',')]
-    params = {}
-    params['label_type'] = args.label_type
-    params['batch_size'] = args.batch_size
-    params['accum_count'] = args.accum_count
-    params['max_lr'] = args.max_lr
-    params['warmup_steps'] = args.warmup_steps
-    params['n_epochs'] = args.n_epochs
-    params['valid_steps'] = args.valid_steps
+def configure_training(args: argparse.Namespace):
+    devices = [int(gpu) for gpu in range(torch.cuda.device_count())]
+    params = {
+        'label_type': args.label_type,
+        'batch_size': args.batch_size,
+        'accum_count': args.accum_count,
+        'max_lr': args.max_lr,
+        'warmup_steps': args.warmup_steps,
+        'n_epochs': args.n_epochs,
+        'valid_steps': args.valid_steps
+    }
    return devices, params
 
 
-def train_model(args):
-
+def train_model(args: argparse.Namespace):
    # check if the data_path and save_path exists
    data_paths = get_data_path(args.mode, args.label_type)
    for name in data_paths:
@@ -63,7 +59,7 @@ def train_model(args):
    val_metric = [LossMetric()]
    # sampler = BucketSampler(num_buckets=32, batch_size=args.batch_size)
    trainer = Trainer(train_data=train_set, model=model, optimizer=optimizer,
-                      loss=criterion, batch_size=args.batch_size, # sampler=sampler,
+                      loss=criterion, batch_size=args.batch_size,  # sampler=sampler,
                      update_every=args.accum_count, n_epochs=args.n_epochs,
                      print_every=100, dev_data=valid_set, metrics=val_metric,
                      metric_key='-loss', validate_every=args.valid_steps,
@@ -72,10 +68,11 @@ def train_model(args):
    print('Start training with the following hyper-parameters:')
    print(train_params)
    trainer.train()
-
 
-def test_model(args):
 
-    models = os.listdir(args.save_path)
+def test_model(args: argparse.Namespace):
+
+    models = os.listdir(args.save_path)  # make sure there are *.pt model files under this path
 
    # load dataset
    data_paths = get_data_path(args.mode, args.label_type)
@@ -85,7 +82,7 @@ def test_model(args):
    test_set = datasets.datasets['test']
 
    # only need 1 gpu for testing
-    device = int(args.gpus)
+    device = 0
 
    args.batch_size = 1
 
@@ -99,7 +96,7 @@ def test_model(args):
    # configure testing
    original_path, dec_path, ref_path = get_rouge_path(args.label_type)
    test_metric = RougeMetric(data_path=original_path, dec_path=dec_path,
-                              ref_path=ref_path, n_total = len(test_set))
+                              ref_path=ref_path, n_total=len(test_set))
    tester = Tester(data=test_set, model=model, metrics=[test_metric],
                    batch_size=args.batch_size, device=device)
    tester.test()
@@ -112,13 +109,14 @@ if __name__ == '__main__':
    parser.add_argument('--mode', required=True,
                        help='training or testing of BertSum', type=str)
 
+    # The original CNN/DailyMail dataset only provides abstractive summaries, so users have to build their own (pseudo) extractive labels; label_type names the method used to build them.
    parser.add_argument('--label_type', default='greedy',
                        help='greedy/limit', type=str)
 
    parser.add_argument('--save_path', required=True,
                        help='root of the model', type=str)
-    # example for gpus input: '0,1,2,3'
-    parser.add_argument('--gpus', required=True,
-                        help='available gpus for training(separated by commas)', type=str)
+
+    # use CUDA_VISIBLE_DEVICES=4,5
+    # to select the CUDA devices
 
    parser.add_argument('--batch_size', default=18,
                        help='the training batch size', type=int)
@@ -141,7 +139,3 @@ if __name__ == '__main__':
    else:
        print('Testing process of BertSum !!!')
        test_model(args)
-
-
-
-
diff --git a/fastSum/BertSum/utils.py b/fastSum/BertSum/utils.py
index 2ba848b75b1046ebef6f4d06dd916379c72b24b7..17421c1d171b6c1b8f65d3bdaeab4c632f30145c 100644
--- a/fastSum/BertSum/utils.py
+++ b/fastSum/BertSum/utils.py
@@ -1,7 +1,9 @@
 import os
 from os.path import exists
+from typing import Dict
 
-def get_data_path(mode, label_type):
+
+def get_data_path(mode: str, label_type: str) -> Dict[str, str]:
    paths = {}
    if mode == 'train':
        paths['train'] = 'data/' + label_type + '/bert.train.jsonl'
@@ -10,6 +12,7 @@ def get_data_path(mode, label_type):
        paths['test'] = 'data/' + label_type + '/bert.test.jsonl'
    return paths
 
+
 def get_rouge_path(label_type):
    if label_type == 'others':
        data_path = 'data/' + label_type + '/bert.test.jsonl'
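As a closing note on the indexing trick documented in `BertSum.forward` (model.py above), here is a small self-contained demonstration. The tensor sizes are made up and random values stand in for the BERT output; only `torch` is required.

```python
# Toy demonstration (assumed shapes, random data) of gathering one vector per sentence
# at the [CLS] positions, as done in BertSum.forward above.
import torch

N, seq_len, hidden = 2, 10, 4
bert_out = torch.randn(N, seq_len, hidden)        # stand-in for the last BERT layer
cls_id = torch.tensor([[0, 4, 7],                 # doc 0 has three sentences
                       [0, 5, -1]])               # doc 1 has two; -1 is the padding value
mask_cls = (cls_id != -1).long()

# Row index (N, 1) broadcast against column index (N, doc_len) -> (N, doc_len, hidden):
# sent_emb[i, j] == bert_out[i, cls_id[i, j]]; the -1 slot picks the last position
# and is zeroed out by the mask right after, mirroring the model code.
sent_emb = bert_out[torch.arange(N).unsqueeze(1), cls_id]
sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()

# The Classifier then scores each sentence vector and masks the padded slots again
scores = torch.sigmoid(torch.nn.Linear(hidden, 1)(sent_emb).squeeze(-1)) * mask_cls.float()
print(sent_emb.shape, scores.shape)               # torch.Size([2, 3, 4]) torch.Size([2, 3])
```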