diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py
old mode 100644
new mode 100755
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 26a8f3ca3..fa9466326 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -191,6 +191,7 @@ class PretrainedConfig(object):
         self.pad_token_id = kwargs.pop("pad_token_id", None)
         self.eos_token_id = kwargs.pop("eos_token_id", None)
         self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
+        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

         # task specific arguments
         self.task_specific_params = kwargs.pop("task_specific_params", None)
diff --git a/src/transformers/modeling_albert.py b/src/transformers/modeling_albert.py
old mode 100644
new mode 100755
index 264de8cbc..eb0ed4dfb
--- a/src/transformers/modeling_albert.py
+++ b/src/transformers/modeling_albert.py
@@ -43,7 +43,7 @@ from .modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices
+from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices


 logger = logging.getLogger(__name__)
@@ -69,6 +69,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
     """ Load tf checkpoints in a pytorch model."""
     try:
         import re
+
         import numpy as np
         import tensorflow as tf
     except ImportError:
@@ -286,6 +287,8 @@ class AlbertLayer(nn.Module):
         super().__init__()

         self.config = config
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
         self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attention = AlbertAttention(config)
         self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -297,14 +300,20 @@
         self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
     ):
         attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
-        ffn_output = self.ffn(attention_output[0])
-        ffn_output = self.activation(ffn_output)
-        ffn_output = self.ffn_output(ffn_output)
-        ffn_output = self.dropout(ffn_output)
+
+        ffn_output = apply_chunking_to_forward(
+            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[0],
+        )
         hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

         return (hidden_states,) + attention_output[1:]  # add attentions if we output them

+    def ff_chunk(self, attention_output):
+        ffn_output = self.ffn(attention_output)
+        ffn_output = self.activation(ffn_output)
+        ffn_output = self.ffn_output(ffn_output)
+        return ffn_output
+

 class AlbertLayerGroup(nn.Module):
     def __init__(self, config):
diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py
index 0c956f82b..b22dec2aa 100755
--- a/src/transformers/modeling_bert.py
+++ b/src/transformers/modeling_bert.py
@@ -424,7 +424,7 @@ class BertLayer(nn.Module):
             outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

         layer_output = apply_chunking_to_forward(
-            self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
         )
         outputs = (layer_output,) + outputs
         return outputs
diff --git a/src/transformers/modeling_distilbert.py b/src/transformers/modeling_distilbert.py
old mode 100644
new mode 100755
index ca19495e7..7c932a79f
--- a/src/transformers/modeling_distilbert.py
+++ b/src/transformers/modeling_distilbert.py
@@ -44,7 +44,12 @@ from .modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from .modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)


 logger = logging.getLogger(__name__)
@@ -208,6 +213,8 @@ class FFN(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dropout = nn.Dropout(p=config.dropout)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
         self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
         self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
         assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
@@ -216,6 +223,9 @@ class FFN(nn.Module):
         self.activation = gelu if config.activation == "gelu" else nn.ReLU()

     def forward(self, input):
+        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
+
+    def ff_chunk(self, input):
         x = self.lin1(input)
         x = self.activation(x)
         x = self.lin2(x)
diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py
old mode 100644
new mode 100755
index 033b8bbbc..2623e27f9
--- a/src/transformers/modeling_longformer.py
+++ b/src/transformers/modeling_longformer.py
@@ -41,7 +41,12 @@ from .modeling_outputs import (
     TokenClassifierOutput,
 )
 from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
-from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from .modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)


 logger = logging.getLogger(__name__)
@@ -685,6 +690,8 @@ class LongformerLayer(nn.Module):
         self.attention = LongformerAttention(config, layer_id)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1

     def forward(
         self, hidden_states, attention_mask=None, output_attentions=False,
@@ -693,11 +700,17 @@
         attn_output = self_attn_outputs[0]
         outputs = self_attn_outputs[1:]  # add self attentions if we output attention weights

-        intermediate_output = self.intermediate(attn_output)
-        layer_output = self.output(intermediate_output, attn_output)
+        layer_output = apply_chunking_to_forward(
+            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_output
+        )
         outputs = (layer_output,) + outputs
         return outputs

+    def ff_chunk(self, attn_output):
+        intermediate_output = self.intermediate(attn_output)
+        layer_output = self.output(intermediate_output, attn_output)
+        return layer_output
+

 class LongformerEncoder(nn.Module):
     def __init__(self, config):
diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py
old mode 100644
new mode 100755
index b214c6d5e..2e8c5f31d
--- a/src/transformers/modeling_reformer.py
+++ b/src/transformers/modeling_reformer.py
@@ -1369,7 +1369,7 @@ class ChunkReformerFeedForward(nn.Module):

     def forward(self, attention_output):
         return apply_chunking_to_forward(
-            self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output,
+            self.forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output,
         )

     def forward_chunk(self, hidden_states):
@@ -1730,7 +1730,7 @@ class ReformerOnlyLMHead(nn.Module):
         self.decoder.bias = self.bias

     def forward(self, hidden_states):
-        return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
+        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)

     def forward_chunk(self, hidden_states):
         hidden_states = self.decoder(hidden_states)
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
old mode 100644
new mode 100755
index 62300c1a6..a6ff59ee7
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1519,7 +1519,7 @@ def prune_layer(


 def apply_chunking_to_forward(
-    chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
+    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
 ) -> torch.Tensor:
     """
     This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
@@ -1529,12 +1529,12 @@
     directly applying :obj:`forward_fn` to :obj:`input_tensors`.

     Args:
+        forward_fn (:obj:`Callable[..., torch.Tensor]`):
+            The forward function of the model.
         chunk_size (:obj:`int`):
             The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
         chunk_dim (:obj:`int`):
             The dimension over which the :obj:`input_tensors` should be chunked.
-        forward_fn (:obj:`Callable[..., torch.Tensor]`):
-            The forward function of the model.
         input_tensors (:obj:`Tuple[torch.Tensor]`):
             The input tensors of ``forward_fn`` which will be chunked.
     Returns:
@@ -1550,7 +1550,7 @@

         # implement a chunked forward function
         def forward(self, hidden_states):
-            return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
+            return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
     """

     assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
diff --git a/src/transformers/modeling_xlm.py b/src/transformers/modeling_xlm.py
old mode 100644
new mode 100755
index 26a514b71..4542c6610
--- a/src/transformers/modeling_xlm.py
+++ b/src/transformers/modeling_xlm.py
@@ -50,6 +50,7 @@ from .modeling_utils import (
     PreTrainedModel,
     SequenceSummary,
     SQuADHead,
+    apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
@@ -212,8 +213,13 @@ class TransformerFFN(nn.Module):
         self.lin1 = nn.Linear(in_dim, dim_hidden)
         self.lin2 = nn.Linear(dim_hidden, out_dim)
         self.act = gelu if config.gelu_activation else F.relu
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1

     def forward(self, input):
+        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
+
+    def ff_chunk(self, input):
         x = self.lin1(input)
         x = self.act(x)
         x = self.lin2(x)
diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py
old mode 100644
new mode 100755
index ddb655656..dbdd5d70b
--- a/src/transformers/modeling_xlnet.py
+++ b/src/transformers/modeling_xlnet.py
@@ -35,7 +35,14 @@ from .file_utils import (
     add_start_docstrings_to_callable,
     replace_return_docstrings,
 )
-from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
+from .modeling_utils import (
+    PoolerAnswerClass,
+    PoolerEndLogits,
+    PoolerStartLogits,
+    PreTrainedModel,
+    SequenceSummary,
+    apply_chunking_to_forward,
+)


 logger = logging.getLogger(__name__)
@@ -495,6 +502,8 @@ class XLNetLayer(nn.Module):
         self.rel_attn = XLNetRelativeAttention(config)
         self.ff = XLNetFeedForward(config)
         self.dropout = nn.Dropout(config.dropout)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1

     def forward(
         self,
@@ -524,12 +533,18 @@
         output_h, output_g = outputs[:2]

         if output_g is not None:
-            output_g = self.ff(output_g)
-        output_h = self.ff(output_h)
+            output_g = apply_chunking_to_forward(
+                self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_g
+            )
+        output_h = apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_h)

         outputs = (output_h, output_g) + outputs[2:]  # Add again attentions if there are there
         return outputs

+    def ff_chunk(self, output_x):
+        output_x = self.ff(output_x)
+        return output_x
+

 class XLNetPreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and
diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py
old mode 100644
new mode 100755
index 200b56766..fe336df74
--- a/tests/test_modeling_bert.py
+++ b/tests/test_modeling_bert.py
@@ -26,15 +26,15 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

 if is_torch_available():
     from transformers import (
         BertConfig,
-        BertModel,
-        BertLMHeadModel,
         BertForMaskedLM,
+        BertForMultipleChoice,
         BertForNextSentencePrediction,
         BertForPreTraining,
         BertForQuestionAnswering,
         BertForSequenceClassification,
         BertForTokenClassification,
-        BertForMultipleChoice,
+        BertLMHeadModel,
+        BertModel,
     )
     from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -370,7 +370,6 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    test_chunking = True

     def setUp(self):
         self.model_tester = BertModelTester(self)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
old mode 100644
new mode 100755
index 947ab8c02..b37475b3e
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -25,15 +25,15 @@ from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device


 if is_torch_available():
-    import torch
     import numpy as np
+    import torch

     from transformers import (
         AdaptiveEmbedding,
         PretrainedConfig,
         PreTrainedModel,
-        BertModel,
         BertConfig,
+        BertModel,
         BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -65,7 +65,6 @@ class ModelTesterMixin:
     test_resize_embeddings = True
     test_head_masking = True
     test_missing_keys = True
-    test_chunking = False
     is_encoder_decoder = False

     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -552,9 +551,6 @@
     def test_feed_forward_chunking(self):
         (original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()

-        if not self.test_chunking:
-            return
-
         for model_class in self.all_model_classes:
             torch.manual_seed(0)
             config = copy.deepcopy(original_config)
diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py
index e878c310b..b6e3df069 100644
--- a/tests/test_modeling_reformer.py
+++ b/tests/test_modeling_reformer.py
@@ -555,7 +555,6 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
-    test_chunking = True

     def prepare_kwargs(self):
         return {
@@ -616,7 +615,6 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
-    test_chunking = True

     def prepare_kwargs(self):
         return {
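
A minimal usage sketch of the reordered apply_chunking_to_forward signature (the forward function now comes first), assuming it is importable from transformers.modeling_utils as in the hunk above; the toy ff_chunk function is hypothetical and only illustrates that chunking over the sequence dimension leaves the result unchanged::

    import torch
    from transformers.modeling_utils import apply_chunking_to_forward

    # toy feed-forward chunk: acts on each sequence position independently
    def ff_chunk(hidden_states):
        return hidden_states * 2.0

    hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden)

    # chunk_size=0 leaves the input unchunked; chunk_size=4 splits dim 1 into chunks of 4
    full = apply_chunking_to_forward(ff_chunk, 0, 1, hidden_states)
    chunked = apply_chunking_to_forward(ff_chunk, 4, 1, hidden_states)
    assert torch.allclose(full, chunked)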