diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py
index 624d4a9fe..22c441d47 100755
--- a/src/transformers/models/longformer/modeling_longformer.py
+++ b/src/transformers/models/longformer/modeling_longformer.py
@@ -1246,6 +1246,7 @@ class LongformerEncoder(nn.Module):
         hidden_states,
         attention_mask=None,
         head_mask=None,
+        padding_len=0,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
@@ -1308,6 +1309,16 @@ class LongformerEncoder(nn.Module):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
+        # undo padding
+        if padding_len > 0:
+            # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
+            hidden_states = hidden_states[:, :-padding_len]
+            if output_hidden_states:
+                all_hidden_states = tuple([state[:, :-padding_len] for state in all_hidden_states])
+
+            if output_attentions:
+                all_attentions = tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+
         if not return_dict:
             return tuple(
                 v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
@@ -1697,6 +1708,7 @@ class LongformerModel(LongformerPreTrainedModel):
             embedding_output,
             attention_mask=extended_attention_mask,
             head_mask=head_mask,
+            padding_len=padding_len,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1704,11 +1716,6 @@ class LongformerModel(LongformerPreTrainedModel):
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
-        # undo padding
-        if padding_len > 0:
-            # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
-            sequence_output = sequence_output[:, :-padding_len]
-
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py
index 458133a9b..4054e8c1a 100644
--- a/src/transformers/models/longformer/modeling_tf_longformer.py
+++ b/src/transformers/models/longformer/modeling_tf_longformer.py
@@ -1587,13 +1587,23 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
                 all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
 
                 # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
-                all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))
+                all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)
 
         # Add last layer
         if output_hidden_states:
             hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
             all_hidden_states = all_hidden_states + (hidden_states_to_add,)
 
+        # undo padding
+        # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
+        hidden_states = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
+        if output_attentions:
+            all_attentions = (
+                tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+                if padding_len > 0
+                else all_attentions
+            )
+
         if not return_dict:
             return tuple(
                 v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
@@ -1763,11 +1773,6 @@
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
-        # undo padding
-        if padding_len > 0:
-            # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
-            sequence_output = sequence_output[:, :-padding_len]
-
         if not inputs["return_dict"]:
             return (
                 sequence_output,
diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py
index 2d30bd3ba..a291d0767 100644
--- a/tests/test_modeling_longformer.py
+++ b/tests/test_modeling_longformer.py
@@ -74,12 +74,6 @@ class LongformerModelTester:
         # is x + self.attention_window + 1, where x is the number of tokens with global attention)
         self.key_length = self.attention_window + 2
 
-        # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
-        # the `test_attention_outputs` and `test_hidden_states_output` tests
-        self.encoder_seq_length = (
-            self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
-        )
-
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py
index be96de22a..7ba7b1a25 100644
--- a/tests/test_modeling_tf_longformer.py
+++ b/tests/test_modeling_tf_longformer.py
@@ -74,12 +74,6 @@ class TFLongformerModelTester:
         # because its local attention only attends to `self.attention_window` and one before and one after
         self.key_length = self.attention_window + 2
 
-        # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
-        # the `test_attention_outputs` and `test_hidden_states_output` tests
-        self.encoder_seq_length = (
-            self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
-        )
-
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
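Note on the padding arithmetic (an illustrative sketch, not part of the patch): Longformer pads its input up to the next multiple of `config.attention_window`, so before this change the encoder returned hidden states and attentions at the padded length while only `sequence_output` was trimmed in the model. The minimal PyTorch sketch below shows the trimming the encoder now applies to every returned tensor; the `padding_len` formula is the same one the removed `encoder_seq_length` computation in the testers used, and all other names are placeholders.

import torch

batch_size, seq_length, hidden_size, attention_window = 2, 10, 4, 8

# Round `seq_length` up to a multiple of `attention_window` (here: pad by 6 to 16),
# mirroring the removed `encoder_seq_length` formula in the testers.
padding_len = (attention_window - seq_length % attention_window) % attention_window

# Stand-in for the padded encoder output: [batch, seq + padding, hidden].
hidden_states = torch.randn(batch_size, seq_length + padding_len, hidden_size)

# Undo padding, as LongformerEncoder now does for hidden states and attentions,
# because the calling function expects a length == input_ids.size(1).
if padding_len > 0:
    hidden_states = hidden_states[:, :-padding_len]

assert hidden_states.shape == (batch_size, seq_length, hidden_size)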