From fa1ddced9e5f91fc6fc5cf4f52654bd7189169c7 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 14 Dec 2020 12:32:26 +0100
Subject: [PATCH] [RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers (#9098)

* fix rag

* fix slow test

* fix past in bart
---
 src/transformers/models/bart/modeling_bart.py | 32 +++++++++----------
 src/transformers/models/rag/modeling_rag.py   | 24 +++++++-------
 tests/test_modeling_rag.py                    | 16 ++++------
 3 files changed, 33 insertions(+), 39 deletions(-)

diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 68124fb28..42d753c7b 100644
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -16,7 +16,7 @@
 import math
 import random
 import warnings
-from typing import Dict, Optional, Tuple
+from typing import Optional, Tuple
 
 import numpy as np
 import torch
@@ -407,7 +407,7 @@ class BartDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         encoder_attn_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attn_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[torch.Tensor] = False,
     ):
@@ -416,9 +416,10 @@ class BartDecoderLayer(nn.Module):
             hidden_states = self.self_attn_layer_norm(hidden_states)
 
         # Self Attention
-        # decoder uni-directional self-attention cached key/values tuple is at first position
-        self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
-        hidden_states, self_attn_weights, self_attn_present_key_value = self.self_attn(
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_attn_past_key_value,
             attn_mask=attn_mask,
@@ -437,8 +438,8 @@ class BartDecoderLayer(nn.Module):
         if self.normalize_before:
             hidden_states = self.encoder_attn_layer_norm(hidden_states)
 
-        # cross_attn cached key/values tuple is at second position
-        cross_attn_past_key_value = past_key_value[1] if past_key_value is not None else None
+        # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
         hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
             hidden_states=hidden_states,
             key_value_states=encoder_hidden_states,
@@ -451,6 +452,9 @@ class BartDecoderLayer(nn.Module):
         if not self.normalize_before:
             hidden_states = self.encoder_attn_layer_norm(hidden_states)
 
+        # add cross-attn to positions 3,4 of present_key_value tuple
+        present_key_value = present_key_value + cross_attn_present_key_value
+
         # Fully Connected
         residual = hidden_states
         if self.normalize_before:
@@ -463,9 +467,6 @@ class BartDecoderLayer(nn.Module):
         if not self.normalize_before:
             hidden_states = self.final_layer_norm(hidden_states)
 
-        # make sure decoder uni-directional self-attn at 1st position and cross-attn at 2nd position.
-        present_key_value = (self_attn_present_key_value, cross_attn_present_key_value)
-
         return (
             hidden_states,
             self_attn_weights,
@@ -600,7 +601,7 @@ BART_INPUTS_DOCSTRING = r"""
             :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
             `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
             cross-attention of the decoder.
-        past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+        past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)` and 2 additional tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`):
             Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
 
             If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
@@ -857,7 +858,7 @@ class BartDecoder(BartPretrainedModel):
                 - 0 for tokens that are **masked**.
 
                 `What are attention masks? <../glossary.html#attention-mask>`__
-            past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)` and 2 additional tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`):
                 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
 
@@ -897,7 +898,7 @@ class BartDecoder(BartPretrainedModel):
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
         # past_key_values_length
-        past_key_values_length = past_key_values[0][0][0].shape[2] if past_key_values is not None else 0
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@@ -1284,12 +1285,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
     @staticmethod
     def _reorder_cache(past, beam_idx):
-        def _reorder_buffer(cache: Tuple[torch.Tensor], new_order) -> Dict:
-            return tuple(past_state.index_select(0, new_order) for past_state in cache)
-
         reordered_past = ()
         for layer_past in past:
-            reordered_past += (tuple(_reorder_buffer(cache, beam_idx) for cache in layer_past),)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
 
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index e8219c75f..e2c750877 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -1029,6 +1029,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
         n_docs=None,
         **kwargs
     ):
+        if past is not None:
+            # if past is defined use only last decoder_input_ids
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
         return {
             "input_ids": None,
             "encoder_outputs": encoder_outputs,
@@ -1057,23 +1061,17 @@ class RagTokenForGeneration(RagPreTrainedModel):
     def _reorder_cache(past, beam_idx):
         """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
-        def _reorder_stacked(hidden_states):
-            n_docs = hidden_states.shape[0] // beam_idx.shape[0]
+        def _reorder_stacked(hidden_states, new_order):
+            n_docs = hidden_states.shape[0] // new_order.shape[0]
             hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
-            hidden_states = hidden_states.index_select(0, beam_idx)
-            return hidden_states.view(-1, *hidden_states.shape[2:])
+            hidden_states = hidden_states.index_select(0, new_order)
+            result = hidden_states.view(-1, *hidden_states.shape[2:])
+            return result
 
-        def _reorder_buffer(attn_cache):
-            for k, input_buffer_k in attn_cache.items():
-                if input_buffer_k is not None:
-                    attn_cache[k] = _reorder_stacked(input_buffer_k)
-            return attn_cache
-
-        reordered_past = []
+        reordered_past = ()
         for layer_past in past:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            layer_past_new = {attn_key: _reorder_buffer(attn_cache) for attn_key, attn_cache in layer_past.items()}
-            reordered_past.append(layer_past_new)
+            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
 
         return reordered_past
 
diff --git a/tests/test_modeling_rag.py b/tests/test_modeling_rag.py
index b5efc6fff..382c59c32 100644
--- a/tests/test_modeling_rag.py
+++ b/tests/test_modeling_rag.py
@@ -535,7 +535,6 @@ class RagDPRBartTest(RagTestMixin, unittest.TestCase):
             n_docs=self.n_docs,
             retrieval_vector_size=self.retrieval_vector_size,
             max_combined_length=self.max_combined_length,
-            use_cache=False,
         )
 
         return {
@@ -565,7 +564,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
             n_docs=self.n_docs,
             retrieval_vector_size=self.retrieval_vector_size,
             max_combined_length=self.max_combined_length,
-            use_cache=False,
         )
 
         return {
@@ -758,8 +756,8 @@ class RagModelIntegrationTests(unittest.TestCase):
             generator_tokenizer=rag_decoder_tokenizer,
         )
 
-        rag_token = self.sequence_model
-        rag_token.set_retriever(rag_retriever)
+        rag_sequence = self.sequence_model
+        rag_sequence.set_retriever(rag_retriever)
 
         input_ids = rag_question_encoder_tokenizer(
             "who sings does he love me with reba", return_tensors="pt"
         ).input_ids
 
         input_ids = input_ids.to(torch_device)
 
-        output_ids = rag_token.generate(
+        output_ids = rag_sequence.generate(
             input_ids,
-            decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
+            decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
             num_beams=2,
             num_return_sequences=2,
         )
@@ -810,7 +808,7 @@ class RagModelIntegrationTests(unittest.TestCase):
         retriever = RagRetriever.from_pretrained(
             "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
         )
-        rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
+        rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
             torch_device
         )
@@ -844,9 +842,9 @@ class RagModelIntegrationTests(unittest.TestCase):
             " walls of the abdomen",
             " spodumene",
             " obama",
-            " grainger's compound",
+            " new orleans",
             " japan",
-            " old trafford stadium",
+            " old trafford",
         ]
 
         self.assertListEqual(outputs, EXPECTED_OUTPUTS)
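
---
Reviewer note (not part of the patch): a minimal, self-contained sketch of the flat
per-layer cache layout this change standardizes on. All sizes below are hypothetical
and chosen only for illustration; the reordering loop mirrors the patched
BartForConditionalGeneration._reorder_cache.

    import torch

    # Hypothetical sizes, for illustration only.
    num_layers, batch_size, num_heads, seq_len, head_dim = 2, 4, 4, 5, 8
    enc_len = 7

    # After this patch each decoder layer caches one flat 4-tuple:
    #   (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
    # instead of the previous pair of nested 2-tuples.
    def make_layer_past():
        self_kv = [torch.randn(batch_size, num_heads, seq_len, head_dim) for _ in range(2)]
        cross_kv = [torch.randn(batch_size, num_heads, enc_len, head_dim) for _ in range(2)]
        return tuple(self_kv + cross_kv)

    past = tuple(make_layer_past() for _ in range(num_layers))

    # BartDecoder now reads the cached length with one less level of nesting.
    assert past[0][0].shape[2] == seq_len  # past_key_values_length

    # Beam-search reordering indexes every tensor of the flat tuple directly,
    # exactly as in the patched _reorder_cache above.
    beam_idx = torch.tensor([1, 0, 3, 2])
    reordered_past = ()
    for layer_past in past:
        reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
    assert reordered_past[0][0].shape == past[0][0].shape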
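
Reviewer note (not part of the patch): RAG's reordering differs from BART's only in the
extra document axis. A small sketch with hypothetical sizes, reusing _reorder_stacked as
defined in the patched RagTokenForGeneration._reorder_cache:

    import torch

    # Hypothetical sizes, for illustration only.
    num_beams, n_docs, num_heads, seq_len, head_dim = 2, 3, 4, 5, 8

    def _reorder_stacked(hidden_states, new_order):
        # Cached states are stacked as (num_beams * n_docs, ...) on the batch axis,
        # so regroup by docs, reorder the beam axis, then flatten back.
        n_docs = hidden_states.shape[0] // new_order.shape[0]
        hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
        hidden_states = hidden_states.index_select(0, new_order)
        result = hidden_states.view(-1, *hidden_states.shape[2:])
        return result

    past_state = torch.randn(num_beams * n_docs, num_heads, seq_len, head_dim)
    beam_idx = torch.tensor([1, 0])  # swap the two beams

    reordered = _reorder_stacked(past_state, beam_idx)
    assert reordered.shape == past_state.shape
    # The block of n_docs states that belonged to beam 1 now comes first.
    assert torch.equal(reordered[:n_docs], past_state[n_docs:])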