From 9c4f802b6b58f464270570cd630de56ef79fa906 Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Fri, 6 Nov 2020 20:53:10 +0800
Subject: [PATCH] optimize BertEmbedding and RoBERTaEmbedding: early exit if layer != -1

---
 fastNLP/embeddings/bert_embedding.py    | 24 ++++++++++++++++++------
 fastNLP/embeddings/roberta_embedding.py | 24 +++++++++++++++++-------
 fastNLP/modules/encoder/bert.py         | 18 +++++++++++++++---
 fastNLP/modules/encoder/roberta.py      |  4 ++--
 4 files changed, 52 insertions(+), 18 deletions(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index ec2ba26b..c57d2bef 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -93,7 +93,7 @@ class BertEmbedding(ContextualEmbedding):
         """
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
-        if word_dropout>0:
+        if word_dropout > 0:
             assert vocab.unknown != None, "When word_drop>0, Vocabulary must contain the unknown token."
 
         if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -370,17 +370,29 @@ class _BertWordModel(nn.Module):
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
 
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
-        self.encoder = BertModel.from_pretrained(model_dir_or_name)
-        self._max_position_embeddings = self.encoder.config.max_position_embeddings
-        # check whether encoder_layer_number is reasonable
-        encoder_layer_number = len(self.encoder.encoder.layer)
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
             self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
+        assert len(self.layers) > 0, "There is no layer selected!"
+
+        neg_num_output_layer = -16384
+        pos_num_output_layer = 0
+        for layer in self.layers:
+            if layer < 0:
+                neg_num_output_layer = max(layer, neg_num_output_layer)
+            else:
+                pos_num_output_layer = max(layer, pos_num_output_layer)
+
+        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = BertModel.from_pretrained(model_dir_or_name,
+                                                 neg_num_output_layer=neg_num_output_layer,
+                                                 pos_num_output_layer=pos_num_output_layer)
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings
+        # check whether encoder_layer_number is reasonable
+        encoder_layer_number = len(self.encoder.encoder.layer)
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
diff --git a/fastNLP/embeddings/roberta_embedding.py b/fastNLP/embeddings/roberta_embedding.py
index 90ea1085..ec95abe2 100644
--- a/fastNLP/embeddings/roberta_embedding.py
+++ b/fastNLP/embeddings/roberta_embedding.py
@@ -196,20 +196,30 @@ class _RobertaWordModel(nn.Module):
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
 
-        self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
-        self.encoder = RobertaModel.from_pretrained(model_dir_or_name)
-        # RobertaEmbedding sets padding_idx to 1 and uses a rather peculiar position calculation, hence the -2
-        self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
-        # check whether encoder_layer_number is reasonable
-        encoder_layer_number = len(self.encoder.encoder.layer)
-
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
             self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
+        assert len(self.layers) > 0, "There is no layer selected!"
+
+        neg_num_output_layer = -16384
+        pos_num_output_layer = 0
+        for layer in self.layers:
+            if layer < 0:
+                neg_num_output_layer = max(layer, neg_num_output_layer)
+            else:
+                pos_num_output_layer = max(layer, pos_num_output_layer)
+        self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
+                                                    neg_num_output_layer=neg_num_output_layer,
+                                                    pos_num_output_layer=pos_num_output_layer)
+        # RobertaEmbedding sets padding_idx to 1 and uses a rather peculiar position calculation, hence the -2
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
+        # check whether encoder_layer_number is reasonable
+        encoder_layer_number = len(self.encoder.encoder.layer)
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 7a9ba57e..8d5d576e 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -366,19 +366,28 @@ class BertLayer(nn.Module):
 
 
 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, num_output_layer=-1):
         super(BertEncoder, self).__init__()
         layer = BertLayer(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        num_output_layer = num_output_layer if num_output_layer >= 0 else (len(self.layer) + num_output_layer)
+        self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
+        if self.num_output_layer + 1 < len(self.layer):
+            logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
+                        f'(start from 0)!')
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
-        for layer_module in self.layer:
+        for idx, layer_module in enumerate(self.layer):
+            if idx > self.num_output_layer:
+                break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if len(all_encoder_layers) == 0:
+            all_encoder_layers.append(hidden_states)
         return all_encoder_layers
 
 
@@ -435,6 +444,9 @@ class BertModel(nn.Module):
         self.config = config
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
+        neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
+        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
@@ -445,7 +457,7 @@ class BertModel(nn.Module):
         else:
             self.embeddings = BertEmbeddings(config)
 
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, num_output_layer=self.num_output_layer)
         if self.model_type != 'distilbert':
             self.pooler = BertPooler(config)
         else:
diff --git a/fastNLP/modules/encoder/roberta.py b/fastNLP/modules/encoder/roberta.py
index da0ab537..10bdb64b 100644
--- a/fastNLP/modules/encoder/roberta.py
+++ b/fastNLP/modules/encoder/roberta.py
@@ -64,8 +64,8 @@ class RobertaModel(BertModel):
     undocumented
     """
 
-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
         self.embeddings = RobertaaEmbeddings(config) if False else RobertaEmbeddings(config)
         self.apply(self.init_bert_weights)
 
--
Gitee
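
Usage note (not part of the patch above): a minimal sketch of how the early-exit path is exercised from the embedding side, assuming fastNLP with this patch applied and that the 'en-base-uncased' alias from PRETRAINED_BERT_MODEL_DIR resolves to a downloadable or cached BERT checkpoint; the toy vocabulary below is illustrative only.

# Minimal sketch, assuming fastNLP with the patch above applied; the model alias
# and toy vocabulary are illustrative, not taken from the patch itself.
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is a toy sentence".split())

# layers='0,3': both indices are non-negative, so pos_num_output_layer == 3 and
# neg_num_output_layer stays at -16384; BertEncoder therefore runs layers 0..3
# only and breaks out of its forward loop (early exit).
embed_early = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='0,3')

# layers='-1': neg_num_output_layer == -1 maps to the last hidden layer, so no
# early exit is possible and every layer is still computed, matching the
# "early exit if layer != -1" wording in the subject line.
embed_full = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1')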