diff --git a/docs/source/index.rst b/docs/source/index.rst index 3dbc6b6dd..1485e9b5b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -296,7 +296,7 @@ TensorFlow and/or Flax. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ -| RAG | ✅ | ❌ | ✅ | ❌ | ❌ | +| RAG | ✅ | ❌ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ diff --git a/docs/source/model_doc/rag.rst b/docs/source/model_doc/rag.rst index 3b7361b16..796b06e73 100644 --- a/docs/source/model_doc/rag.rst +++ b/docs/source/model_doc/rag.rst @@ -94,3 +94,24 @@ RagTokenForGeneration .. autoclass:: transformers.RagTokenForGeneration :members: forward, generate + + +TFRagModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFRagModel + :members: call + + +TFRagSequenceForGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFRagSequenceForGeneration + :members: call, generate + + +TFRagTokenForGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFRagTokenForGeneration + :members: call, generate diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2b6a03789..ce05881cf 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1130,6 +1130,13 @@ if is_tf_available(): ] ) _import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"]) + _import_structure["models.rag"].extend( + [ + "TFRagModel", + "TFRagSequenceForGeneration", + "TFRagTokenForGeneration", + ] + ) _import_structure["models.roberta"].extend( [ "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2166,6 +2173,7 @@ if TYPE_CHECKING: TFOpenAIGPTPreTrainedModel, ) from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel + from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration from .models.roberta import ( TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TFRobertaForMaskedLM, diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 4158cadea..84a7880d0 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -441,6 +441,7 @@ class TFGenerationMixin: encoder_outputs, attention_mask, use_cache, + **kwargs ): """ Generate sequences for each example without beam search (num_beams == 1). 
All returned sequence are generated
@@ -455,7 +456,7 @@ class TFGenerationMixin:
 
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
             )
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]
@@ -609,6 +610,7 @@ class TFGenerationMixin:
         use_cache,
         forced_bos_token_id,
         forced_eos_token_id,
+        **kwargs,
     ):
         """Generate sequences for each example with beam search."""
 
@@ -637,7 +639,7 @@ class TFGenerationMixin:
 
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
             )
             outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
             next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 720a05259..c97032676 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -447,7 +447,7 @@ def input_processing(func, config, input_ids, **kwargs):
     return output
 
 
-def load_tf_weights(model, resolved_archive_file):
+def load_tf_weights(model, resolved_archive_file, _prefix=None):
     """
     Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
 
@@ -493,6 +493,10 @@ def load_tf_weights(model, resolved_archive_file):
         for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
             # TF names always start with the model name so we ignore it
             name = "/".join(weight_name.split("/")[1:])
+
+            if _prefix is not None:
+                name = _prefix + "/" + name
+
             saved_weights[name] = np.asarray(h5_layer_object[weight_name])
 
             # Add the updated name to the final list for computing missing/unexpected values
@@ -501,7 +505,14 @@ def load_tf_weights(model, resolved_archive_file):
     # Loop over each weights from the instantiated model and compare with the weights from the H5 file
     for symbolic_weight in symbolic_weights:
         # TF names always start with the model name so we ignore it
-        symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
+        if _prefix is not None:
+            delimiter = len(_prefix.split("/"))
+            symbolic_weight_name = "/".join(
+                symbolic_weight.name.split("/")[:delimiter]
+                + symbolic_weight.name.split("/")[delimiter + 1 :]
+            )
+        else:
+            symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
 
         # here we check if the current weight is among the weights from the H5 file
         # If yes, get the weight_value of the corresponding weight from the H5 file
@@ -603,6 +614,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     # a list of re pattern of tensor names to ignore from the weights when loading the model weights
     # (and avoid unnecessary warnings).
     _keys_to_ignore_on_load_unexpected = None
+    _requires_load_weight_prefix = False
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -1052,7 +1064,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                       Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a
                       user or organization name, like ``dbmdz/bert-base-german-cased``.
                     - A path to a `directory` containing model weights saved using
-                      :func:`~transformersTF.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
+                      :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                     - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
                       this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                       as ``config`` argument. This loading path is slower than converting the PyTorch model in a
@@ -1151,6 +1163,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)
+        load_weight_prefix = kwargs.pop("load_weight_prefix", None)
 
         if is_offline_mode() and not local_files_only:
             logger.info("Offline mode: forcing local_files_only=True")
@@ -1230,6 +1243,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
 
         config.name_or_path = pretrained_model_name_or_path
 
+        # composed models, *e.g.* TFRag, require special treatment when it comes to loading
+        # pre-trained weights.
+        if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None:
+            model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name")
+
         # Instantiate model.
         model = cls(config, *model_args, **model_kwargs)
 
@@ -1239,13 +1257,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 # Load from a PyTorch checkpoint
                 return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)
 
-            model(model.dummy_inputs)  # build the network with dummy inputs
+            # we might need to extend the variable scope for composite models
+            if load_weight_prefix is not None:
+                with tf.compat.v1.variable_scope(load_weight_prefix):
+                    model(model.dummy_inputs)  # build the network with dummy inputs
+            else:
+                model(model.dummy_inputs)  # build the network with dummy inputs
 
             assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
             # 'by_name' allow us to do transfer learning by skipping/adding layers
             # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
             try:
-                missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file)
+                missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file, load_weight_prefix)
             except OSError:
                 raise OSError(
                     "Unable to load weights from h5 file. 
" diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 43ce55dba..fff403f1a 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -553,7 +553,7 @@ class TFAutoModel(object): @classmethod @replace_list_option_in_docstrings(TF_MODEL_MAPPING, use_model_types=False) - def from_config(cls, config): + def from_config(cls, config, **kwargs): r""" Instantiates one of the base model classes of the library from a configuration. @@ -575,7 +575,7 @@ class TFAutoModel(object): >>> model = TFAutoModel.from_config(config) """ if type(config) in TF_MODEL_MAPPING.keys(): - return TF_MODEL_MAPPING[type(config)](config) + return TF_MODEL_MAPPING[type(config)](config, **kwargs) raise ValueError( "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n" "Model type should be one of {}.".format( @@ -1037,7 +1037,7 @@ class TFAutoModelForSeq2SeqLM: @classmethod @replace_list_option_in_docstrings(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, use_model_types=False) - def from_config(cls, config): + def from_config(cls, config, **kwargs): r""" Instantiates one of the model classes of the library---with a sequence-to-sequence language modeling head---from a configuration. @@ -1061,7 +1061,7 @@ class TFAutoModelForSeq2SeqLM: >>> model = TFAutoModelForSeq2SeqLM.from_config(config) """ if type(config) in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys(): - return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config) + return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config, **kwargs) raise ValueError( "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n" "Model type should be one of {}.".format( diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index ce67fc654..5a1fb467f 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -1015,13 +1015,16 @@ class TFBartDecoder(tf.keras.layers.Layer): class TFBartMainLayer(tf.keras.layers.Layer): config_class = BartConfig - def __init__(self, config: BartConfig, **kwargs): + def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs): super().__init__(**kwargs) - self.config = config self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared") - with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: + # set tf scope correctly + if load_weight_prefix is None: + load_weight_prefix = "model.shared" + + with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name: pass # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. 
@@ -1157,10 +1160,13 @@ class TFBartMainLayer(tf.keras.layers.Layer): BART_START_DOCSTRING, ) class TFBartModel(TFBartPretrainedModel): - def __init__(self, config: BartConfig, *inputs, **kwargs): + + _requires_load_weight_prefix = True + + def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) - self.model = TFBartMainLayer(config, name="model") + self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model") def get_encoder(self): return self.model.encoder @@ -1263,9 +1269,11 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode r"model.decoder.embed_tokens.weight", ] - def __init__(self, config, *inputs, **kwargs): + _requires_load_weight_prefix = True + + def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) - self.model = TFBartMainLayer(config, name="model") + self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model") self.use_cache = config.use_cache # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency. self.final_logits_bias = self.add_weight( diff --git a/src/transformers/models/rag/__init__.py b/src/transformers/models/rag/__init__.py index 751553ef5..0c96db875 100644 --- a/src/transformers/models/rag/__init__.py +++ b/src/transformers/models/rag/__init__.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING -from ...file_utils import _BaseLazyModule, is_torch_available +from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available _import_structure = { @@ -30,6 +30,9 @@ _import_structure = { if is_torch_available(): _import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"] +if is_tf_available(): + _import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"] + if TYPE_CHECKING: from .configuration_rag import RagConfig @@ -39,6 +42,9 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration + if is_tf_available(): + from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration + else: import importlib import os diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py new file mode 100644 index 000000000..84e0f50c3 --- /dev/null +++ b/src/transformers/models/rag/modeling_tf_rag.py @@ -0,0 +1,1832 @@ +# coding=utf-8 +# Copyright 2020, The RAG Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TFRAG model implementation.""" + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import numpy as np +import tensorflow as tf + +from ...configuration_utils import PretrainedConfig +from ...file_utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_tf_outputs import TFBaseModelOutput +from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, input_processing, shape_list +from ...utils import logging +from .configuration_rag import RagConfig +from .retrieval_rag import RagRetriever + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "RagConfig" + + +@dataclass +class TFRetrievAugLMMarginOutput(ModelOutput): + """ + Base class for retriever augmented marginalized models outputs. + + Args: + loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Language modeling loss. + logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. The score is possibly marginalized over all documents for + each vocabulary token. + past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, + num_heads, sequence_length, embed_size_per_head)`). + + Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used + (see :obj:`past_key_values` input) to speed up sequential decoding. + doc_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.n_docs)`): + Score between each retrieved document embeddings (see :obj:`retrieved_doc_embeds`) and + :obj:`question_encoder_last_hidden_state`. + retrieved_doc_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.n_docs, hidden_size)`, `optional`, returned when `output_retrieved=True`): + Embedded documents retrieved by the retriever. Is used with ``question_encoder_last_hidden_state`` to + compute the ``doc_scores``. + retrieved_doc_ids (:obj:`tf.Tensor` (int32) of shape :obj:`(batch_size, config.n_docs)`, `optional`, returned when `output_retrieved=True`): + The indexes of the embedded documents retrieved by the retriever. + context_input_ids (:obj:`tf.Tensor`(int32) of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever. + context_attention_mask (:obj:`tf.Tensor` (int32) of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Attention mask post-processed from the retrieved documents and the question encoder :obj:`input_ids` by the + retriever. + question_encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden states at the output of the last layer of the question encoder pooled output of the + model. + question_enc_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. 
+ + Hidden states of the question encoder at the output of each layer plus the initial embedding outputs. + question_enc_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the question encoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + generator_enc_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the generator encoder of the model. + generator_enc_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs. + generator_enc_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + generator_dec_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs. + generator_dec_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. 
+ """ + + loss: Optional[tf.Tensor] = None + logits: tf.Tensor = None + past_key_values: Optional[List[tf.Tensor]] = None + doc_scores: Optional[tf.Tensor] = None + retrieved_doc_embeds: Optional[tf.Tensor] = None + retrieved_doc_ids: Optional[tf.Tensor] = None + context_input_ids: Optional[tf.Tensor] = None + context_attention_mask: Optional[tf.Tensor] = None + question_encoder_last_hidden_state: Optional[tf.Tensor] = None + question_enc_hidden_states: Optional[Tuple[tf.Tensor]] = None + question_enc_attentions: Optional[Tuple[tf.Tensor]] = None + generator_enc_last_hidden_state: Optional[tf.Tensor] = None + generator_enc_hidden_states: Optional[Tuple[tf.Tensor]] = None + generator_enc_attentions: Optional[Tuple[tf.Tensor]] = None + generator_dec_hidden_states: Optional[Tuple[tf.Tensor]] = None + generator_dec_attentions: Optional[Tuple[tf.Tensor]] = None + + +@dataclass +class TFRetrievAugLMOutput(ModelOutput): + """ + Args: + logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. The score is possibly marginalized over all documents for + each vocabulary token. + past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, + num_heads, sequence_length, embed_size_per_head)`). + + Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used + (see :obj:`past_key_values` input) to speed up sequential decoding. + doc_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.n_docs)`): + Score between each retrieved document embeddings (see :obj:`retrieved_doc_embeds`) and + :obj:`question_encoder_last_hidden_state`. + retrieved_doc_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.n_docs, hidden_size)`, `optional`, returned when `output_retrieved=True`): + Embedded documents retrieved by the retriever. Is used with ``question_encoder_last_hidden_state`` to + compute the ``doc_scores``. + retrieved_doc_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.n_docs)`, `optional`, returned when `output_retrieved=True`): + The indexes of the embedded documents retrieved by the retriever. + context_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever. + context_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Attention mask post-processed from the retrieved documents and the question encoder :obj:`input_ids` by the + retriever. + question_encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden states at the output of the last layer of the question encoder pooled output of the + model. + question_enc_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. 
+ + Hidden states of the question encoder at the output of each layer plus the initial embedding outputs. + question_enc_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the question encoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + generator_enc_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the generator encoder of the model. + generator_enc_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs. + generator_enc_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + generator_dec_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs. + generator_dec_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. 
+ """ + + logits: tf.Tensor = None + past_key_values: Optional[List[tf.Tensor]] = None + doc_scores: Optional[tf.Tensor] = None + retrieved_doc_embeds: Optional[tf.Tensor] = None + retrieved_doc_ids: Optional[tf.Tensor] = None + context_input_ids: Optional[tf.Tensor] = None + context_attention_mask: Optional[tf.Tensor] = None + question_encoder_last_hidden_state: Optional[tf.Tensor] = None + question_enc_hidden_states: Optional[Tuple[tf.Tensor]] = None + question_enc_attentions: Optional[Tuple[tf.Tensor]] = None + generator_enc_last_hidden_state: Optional[tf.Tensor] = None + generator_enc_hidden_states: Optional[Tuple[tf.Tensor]] = None + generator_enc_attentions: Optional[Tuple[tf.Tensor]] = None + generator_dec_hidden_states: Optional[Tuple[tf.Tensor]] = None + generator_dec_attentions: Optional[Tuple[tf.Tensor]] = None + + +class TFRagPreTrainedModel(TFPreTrainedModel): + r""" + RAG models were released with the paper `Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks + `__ by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al. + + RAG is a retriever augmented model and encapsulate three components: a question encoder, a dataset retriever and a + generator, the encoder and generator are trainable while the retriever is just an indexed dataset. + + """ + config_class = RagConfig + base_model_prefix = "rag" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + @classmethod + def from_pretrained_question_encoder_generator( + cls, + question_encoder_pretrained_model_name_or_path: str = None, + generator_pretrained_model_name_or_path: str = None, + retriever: RagRetriever = None, + *model_args, + **kwargs + ) -> TFPreTrainedModel: + r""" + Instantiates an question encoder and a generator from one or two base classes of the library from pretrained + model checkpoints. + + Params: + question_encoder_pretrained_model_name_or_path (:obj: `str`, `optional`): + Information necessary to initiate the question encoder. Can be either: + + - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g., + ``bert-base-uncased``. + - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g., + ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `pytorch index checkpoint file` (e.g, ``./pt_model/``). In this case, + ``question_encoder_from_pt`` should be set to :obj:`True`. + + generator_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`): + Information necessary to initiate the generator. Can be either: + + - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g., + ``t5-small``. + - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g., + ``facebook/bart-base``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `pytorch checkpoint file` (e.g, ``./pt_model/``). In this case, + ``generator_from_pt`` should be set to :obj:`True`. + + model_args (remaining positional arguments, `optional`): + All remaning positional arguments will be passed to the underlying model's ``__init__`` method. + retriever (:class:`~transformers.RagRetriever`, `optional`): + The retriever to use. 
kwargs (remaining dictionary of keyword arguments, `optional`):
+                Can be used to update the configuration object (after it has been loaded) and initiate the model
+                (e.g., ``output_attentions=True``).
+
+                - To update the question_encoder configuration, use the prefix `question_encoder_` for each
+                  configuration parameter.
+                - To update the generator configuration, use the prefix `generator_` for each configuration parameter.
+                - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+                Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
+
+        Example::
+
+            >>> from transformers import RagRetriever, TFRagModel
+            >>> # initialize a RAG model from two pretrained models.
+            >>> model = TFRagModel.from_pretrained_question_encoder_generator('facebook/dpr-question_encoder-single-nq-base', 't5-small')
+            >>> # alternatively, initializing from pretrained PyTorch models is also possible
+            >>> model = TFRagModel.from_pretrained_question_encoder_generator('facebook/dpr-question_encoder-single-nq-base', "facebook/bart-base", generator_from_pt=True, question_encoder_from_pt=True)
+
+            >>> # saving model after fine-tuning
+            >>> model.save_pretrained("./rag")
+
+            >>> # load retriever
+            >>> retriever = RagRetriever.from_pretrained(PATH, index_name="exact", use_dummy_dataset=True)
+            >>> # load fine-tuned model with retriever
+            >>> model = TFRagModel.from_pretrained("./rag", retriever=retriever)
+        """
+
+        kwargs_question_encoder = {
+            argument[len("question_encoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("question_encoder_")
+        }
+
+        kwargs_generator = {
+            argument[len("generator_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("generator_")
+        }
+
+        # remove question_encoder, generator kwargs from kwargs
+        for key in kwargs_question_encoder.keys():
+            del kwargs["question_encoder_" + key]
+        for key in kwargs_generator.keys():
+            del kwargs["generator_" + key]
+
+        # Load and initialize the question_encoder and generator
+        # The distinction between question_encoder and generator at the model level is made
+        # by the value of the flag `is_generator` that we need to set correctly.
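# Illustrative sketch, not part of the diff: how the prefixed kwargs above are
# routed to the two sub-models (all keys and values here are hypothetical).
kwargs = {"question_encoder_from_pt": True, "generator_from_pt": True, "n_docs": 5}
kwargs_question_encoder = {
    k[len("question_encoder_") :]: v for k, v in kwargs.items() if k.startswith("question_encoder_")
}
kwargs_generator = {k[len("generator_") :]: v for k, v in kwargs.items() if k.startswith("generator_")}
assert kwargs_question_encoder == {"from_pt": True}
assert kwargs_generator == {"from_pt": True}
# the unprefixed "n_docs" stays behind and ends up in the RagConfig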
+        question_encoder = kwargs_question_encoder.pop("model", None)
+        if question_encoder is None:
+            assert (
+                question_encoder_pretrained_model_name_or_path is not None
+            ), "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to be defined"
+
+            from ..auto.modeling_tf_auto import TFAutoModel
+
+            if "config" not in kwargs_question_encoder:
+                from ..auto.configuration_auto import AutoConfig
+
+                question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path)
+                kwargs_question_encoder["config"] = question_encoder_config
+
+            question_encoder = TFAutoModel.from_pretrained(
+                question_encoder_pretrained_model_name_or_path,
+                name="question_encoder",
+                load_weight_prefix=cls.load_weight_prefix,
+                *model_args,
+                **kwargs_question_encoder,
+            )
+
+        generator = kwargs_generator.pop("generator", None)
+        if generator is None:
+            assert (
+                generator_pretrained_model_name_or_path is not None
+            ), "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has to be defined"
+
+            from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM
+
+            if "config" not in kwargs_generator:
+                from ..auto.configuration_auto import AutoConfig
+
+                generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path)
+                kwargs_generator["config"] = generator_config
+
+            generator = TFAutoModelForSeq2SeqLM.from_pretrained(
+                generator_pretrained_model_name_or_path,
+                name="generator",
+                load_weight_prefix=cls.load_weight_prefix,
+                **kwargs_generator,
+            )
+
+        # instantiate config with corresponding kwargs
+        config = kwargs.get("config", None)
+        if config is None:
+            config = RagConfig.from_question_encoder_generator_configs(
+                question_encoder.config, generator.config, **kwargs
+            )
+
+        return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever)
+
+
+RAG_START_DOCSTRING = r"""
+
+    RAG is a sequence-to-sequence model which encapsulates two core components: a question encoder and a generator.
+    During a forward pass, we encode the input with the question encoder and pass it to the retriever to extract
+    relevant context documents. The documents are then prepended to the input. Such contextualized inputs are passed
+    to the generator.
+
+    The question encoder can be any `autoencoding` model, preferably :class:`~transformers.TFDPRQuestionEncoder`, and
+    the generator can be any `seq2seq` model, preferably :class:`~transformers.TFBartForConditionalGeneration`.
+
+    The model can be initialized with a :class:`~transformers.RagRetriever` for end-to-end generation or used in
+    combination with the outputs of a retriever in multiple steps---see examples for more details. The model is
+    compatible with any `autoencoding` model as the ``question_encoder`` and any `seq2seq` model with a language
+    model head as the ``generator``. It has been tested with :class:`~transformers.TFDPRQuestionEncoder` as the
+    ``question_encoder`` and :class:`~transformers.TFBartForConditionalGeneration` as the ``generator``.
+
+    This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
+    generic methods the library implements for all its models (such as downloading or saving, resizing the input
+    embeddings, pruning heads etc.)
+
+    This model is also a TensorFlow `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__
+    subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to
+    general usage and behavior.
+
+    The model is in a developing state: it is currently fully supported in eager mode only, and it may not be
+    exportable in SavedModel format.
+
+    Args:
+        config (:class:`~transformers.RagConfig`):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
+        question_encoder (:class:`transformers.TFPreTrainedModel`):
+            An encoder model compatible with the faiss index encapsulated by the ``retriever``.
+        generator (:class:`transformers.TFPreTrainedModel`):
+            A seq2seq model used as the generator in the RAG architecture.
+        retriever (:class:`~transformers.RagRetriever`):
+            A retriever class encapsulating a faiss index queried to obtain context documents for current inputs.
+"""
+
+
+RAG_FORWARD_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. :class:`~transformers.RagConfig`, used to initialize
+            the model, specifies which generator to use; it also specifies a compatible generator tokenizer. Use that
+            tokenizer class to obtain the indices.
+        attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        encoder_outputs (:obj:`tuple(tuple(tf.Tensor))`, `optional`):
+            Tuple consists of (:obj:`generator_enc_last_hidden_state`, `optional`: :obj:`generator_enc_hidden_states`,
+            `optional`: :obj:`generator_enc_attentions`). :obj:`generator_enc_last_hidden_state` of shape
+            :obj:`(batch_size, n_docs * sequence_length, hidden_size)` is a sequence of hidden-states at the output of
+            the last layer of the generator's encoder.
+
+            Used by the (:class:`~transformers.TFRagModel`) model during decoding.
+        decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
+            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
+            you're using with your RAG instance.
+        decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
+            Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
+            also be used by default.
+        past_key_values (:obj:`tuple(tuple(tf.Tensor))`):
+            Tuple consists of two elements: :obj:`encoder_outputs` of the RAG model (see :obj:`encoder_outputs`) and
+            :obj:`past_key_values` of the underlying generator. Can be used to speed up decoding.
+            :obj:`past_key_values` are used in the (:class:`~transformers.RagTokenForGeneration`) model during
+            decoding.
+        doc_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.n_docs)`):
+            Score between each retrieved document embeddings (see :obj:`retrieved_doc_embeds`) and
+            :obj:`question_encoder_last_hidden_state`. If the model is not initialized with a ``retriever``,
+            :obj:`doc_scores` has to be provided to the forward pass. :obj:`doc_scores` can be computed via
+            :obj:`question_encoder_last_hidden_state` and :obj:`retrieved_doc_embeds`, see examples for more
+            information.
+        context_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
+            Input IDs post-processed from the retrieved documents and the question encoder :obj:`input_ids` by the
+            retriever.
+
+            If the model is not initialized with a ``retriever``, :obj:`context_input_ids` has to be provided to the
+            forward pass. :obj:`context_input_ids` are returned by :meth:`~transformers.RagRetriever.__call__`.
+        context_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
+            Attention mask post-processed from the retrieved documents and the question encoder :obj:`input_ids` by the
+            retriever.
+
+            If the model is not initialized with a ``retriever``, :obj:`context_attention_mask` has to be provided
+            to the forward pass. :obj:`context_attention_mask` are returned by
+            :meth:`~transformers.RagRetriever.__call__`.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        output_attentions (:obj:`bool`, `optional`):
+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+            tensors for more detail.
+        output_hidden_states (:obj:`bool`, `optional`):
+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+            more detail.
+        output_retrieved (:obj:`bool`, `optional`):
+            Whether or not to return the :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`,
+            :obj:`context_input_ids` and :obj:`context_attention_mask`. See returned tensors for more detail.
+        return_dict (:obj:`bool`, `optional`):
+            Whether or not to return a :class:`~TFRetrievAugLMOutput` instead of a plain tuple.
+        n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`):
+            Number of documents to retrieve and/or number of documents for which to generate an answer.
+"""
+
+
+@add_start_docstrings_to_model_forward(RAG_START_DOCSTRING)
+class TFRagModel(TFRagPreTrainedModel):
+
+    load_weight_prefix = "tf_rag_model_1"
+
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        question_encoder: Optional[TFPreTrainedModel] = None,
+        generator: Optional[TFPreTrainedModel] = None,
+        retriever: Optional = None,
+        load_weight_prefix: Optional[str] = None,
+        **kwargs,
+    ):
+        assert config is not None or (
+            question_encoder is not None and generator is not None
+        ), "Either a configuration or a question_encoder and a generator have to be provided."
+
+        if config is None:
+            config = RagConfig.from_question_encoder_generator_configs(
+                question_encoder.config, generator.config, **kwargs
+            )
+        else:
+            assert isinstance(config, self.config_class), "config: {} has to be of type {}".format(
+                config, self.config_class
+            )
+        super().__init__(config, **kwargs)
+
+        if question_encoder is None:
+            from ..auto.modeling_tf_auto import TFAutoModel
+
+            question_encoder = TFAutoModel.from_config(config.question_encoder, name="question_encoder")
+
+        if generator is None:
+            from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM
+
+            load_weight_prefix = load_weight_prefix if load_weight_prefix is not None else self.load_weight_prefix
+            generator = TFAutoModelForSeq2SeqLM.from_config(
+                config.generator, name="generator", load_weight_prefix=load_weight_prefix + "/generator"
+            )
+
+        self.retriever = retriever
+        if self.retriever is not None:
+            assert isinstance(
+                retriever, RagRetriever
+            ), f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"
+
+        self.question_encoder = question_encoder
+        self.generator = generator
+
+    def set_retriever(self, retriever: RagRetriever):
+        self.retriever = retriever
+
+    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFRetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        encoder_outputs=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        past_key_values=None,
+        doc_scores=None,
+        context_input_ids=None,
+        context_attention_mask=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        output_retrieved=None,
+        n_docs=None,
+        return_dict=None,
+        training=False,
+        **kwargs
+    ):
+        r"""
+        Returns:
+
+        Example::
+
+            >>> from transformers import RagTokenizer, RagRetriever, TFRagModel
+
+            >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
+            >>> retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="exact", use_dummy_dataset=True)
+            >>> # initialize with RagRetriever to do everything in one forward call
+            >>> model = TFRagModel.from_pretrained("facebook/rag-token-base", retriever=retriever, from_pt=True)
+
+            >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf")
+            >>> input_ids = input_dict["input_ids"]
+            >>> outputs = model(input_ids)
+
+        """
+        assert (
+            "decoder_cached_states" not in kwargs
+        ), "Please use past_key_values to cache intermediate outputs"  # from modeling_tf_bart.py
+
+        inputs = input_processing(
+            func=self.call,
+            config=self.config,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            doc_scores=doc_scores,
+            context_input_ids=context_input_ids,
+            context_attention_mask=context_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_retrieved=output_retrieved,
+            return_dict=return_dict,
+            n_docs=n_docs,
+            training=training,
+            kwargs_call=kwargs,
+        )
+
+        # aliasing to minimize code changing
+        input_ids = inputs["input_ids"]
+        attention_mask = inputs["attention_mask"]
+        decoder_input_ids = inputs["decoder_input_ids"]
+        decoder_attention_mask = inputs["decoder_attention_mask"]
+        encoder_outputs = inputs["encoder_outputs"]
inputs["encoder_outputs"] + past_key_values = inputs["past_key_values"] + doc_scores = inputs["doc_scores"] + context_input_ids = inputs["context_input_ids"] + context_attention_mask = inputs["context_attention_mask"] + + use_cache = inputs["use_cache"] + output_attentions = inputs["output_attentions"] + output_hidden_states = inputs["output_hidden_states"] + return_dict = inputs["return_dict"] + n_docs = inputs["n_docs"] if inputs["n_docs"] is not None else self.config.n_docs + output_retrieved = inputs["output_retrieved"] + training = inputs["training"] + + # whether retriever has to be used + has_to_retrieve = ( + self.retriever is not None + and (context_input_ids is None or context_attention_mask is None or doc_scores is None) + and encoder_outputs is None + ) + + # encoder_outputs are pre-computed during RAG-token generation + if encoder_outputs is None: + + if has_to_retrieve: + question_enc_outputs = self.question_encoder( + input_ids, attention_mask=attention_mask, return_dict=True, training=training + ) + # see https://github.com/huggingface/transformers/blob/master/src/transformers/models/dpr/modeling_tf_dpr.py#L91 + question_encoder_last_hidden_state = question_enc_outputs[ + 0 + ] # hidden states of question encoder => pooler_output + + retriever_outputs = self.retriever( + input_ids, + question_encoder_last_hidden_state.numpy(), + prefix=self.generator.config.prefix, + n_docs=n_docs, + return_tensors="tf", + ) + context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = ( + retriever_outputs["context_input_ids"], + retriever_outputs["context_attention_mask"], + retriever_outputs["retrieved_doc_embeds"], + retriever_outputs["doc_ids"], + ) + + context_input_ids = tf.cast(context_input_ids, tf.int32) + context_attention_mask = tf.cast(context_attention_mask, tf.int32) + retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32) + retrieved_doc_ids = tf.cast(retrieved_doc_ids, tf.int32) + + # compute doc_scores + doc_scores = tf.squeeze( + tf.matmul( + tf.expand_dims(question_encoder_last_hidden_state, axis=1), + retrieved_doc_embeds, + transpose_b=True, + ), + axis=1, + ) + + else: + assert ( + context_input_ids is not None + ), "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function." + assert ( + context_attention_mask is not None + ), "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function." + assert ( + doc_scores is not None + ), "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function." + + assert ( + doc_scores is not None + ), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function." + + assert ( + doc_scores.shape[1] % n_docs + ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}." 
+
+        # Decoder input without context documents
+        if decoder_input_ids is not None:
+            decoder_input_ids = tf.repeat(decoder_input_ids, n_docs, axis=0)
+
+        if decoder_attention_mask is not None:
+            decoder_attention_mask = tf.repeat(decoder_attention_mask, n_docs, axis=0)
+
+        gen_outputs = self.generator(
+            context_input_ids,
+            attention_mask=context_attention_mask,
+            encoder_outputs=encoder_outputs,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            return_dict=True,
+            training=training,
+        )
+
+        if not has_to_retrieve:
+            question_encoder_last_hidden_state = None
+            question_enc_hidden_states = None
+            question_enc_attentions = None
+            retrieved_doc_embeds = None
+            retrieved_doc_ids = None
+        else:
+            question_enc_hidden_states = question_enc_outputs.hidden_states
+            question_enc_attentions = question_enc_outputs.attentions
+
+        if not has_to_retrieve or not output_retrieved:
+            # don't output retrieved docs
+            context_input_ids = None
+            context_attention_mask = None
+            retrieved_doc_embeds = None
+            retrieved_doc_ids = None
+
+        return TFRetrievAugLMOutput(
+            logits=gen_outputs.logits,
+            doc_scores=doc_scores,
+            past_key_values=gen_outputs.past_key_values,
+            context_input_ids=context_input_ids,
+            context_attention_mask=context_attention_mask,
+            retrieved_doc_embeds=retrieved_doc_embeds,
+            retrieved_doc_ids=retrieved_doc_ids,
+            question_encoder_last_hidden_state=question_encoder_last_hidden_state,
+            question_enc_hidden_states=question_enc_hidden_states,
+            question_enc_attentions=question_enc_attentions,
+            generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state,
+            generator_enc_hidden_states=gen_outputs.encoder_hidden_states,
+            generator_enc_attentions=gen_outputs.encoder_attentions,
+            generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
+            generator_dec_attentions=gen_outputs.decoder_attentions,
+        )
+
+
+@add_start_docstrings_to_model_forward(
+    """
+    A TF RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
+    """,
+    RAG_START_DOCSTRING,
+)
+class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
+
+    load_weight_prefix = "tf_rag_token_for_generation_1/rag"
+
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        question_encoder: Optional[TFPreTrainedModel] = None,
+        generator: Optional[TFPreTrainedModel] = None,
+        retriever: Optional = None,
+        **kwargs,
+    ):
+        assert config is not None or (
+            question_encoder is not None and generator is not None
+        ), "Either a configuration or an encoder and a generator have to be provided."
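# Illustrative sketch, not part of the diff: the RAG-token marginalization this
# class implements in its marginalize() method further below. Token
# log-probabilities are combined with document log-probabilities and the
# document axis is summed out in log space (all sizes here are hypothetical).
import tensorflow as tf

batch_size, n_docs, seq_len, vocab = 2, 5, 7, 11
seq_logits = tf.random.normal((batch_size * n_docs, seq_len, vocab))
doc_scores = tf.random.normal((batch_size, n_docs))

seq_logprobs = tf.reshape(tf.nn.log_softmax(seq_logits, axis=-1), (batch_size, n_docs, seq_len, vocab))
doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1)[:, :, None, None]
marginalized = tf.reduce_logsumexp(seq_logprobs + doc_logprobs, axis=1)
assert marginalized.shape == (batch_size, seq_len, vocab)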
+ + if config is None: + config = RagConfig.from_question_encoder_generator_configs( + question_encoder.config, generator.config, **kwargs + ) + + super().__init__(config) + + # instantiate model + self.rag = TFRagModel( + config=config, + question_encoder=question_encoder, + generator=generator, + retriever=retriever, + load_weight_prefix=self.load_weight_prefix, + name="rag", + ) + + def set_retriever(self, retriever: RagRetriever): + self.rag.retriever = retriever + + # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_tf_bart.py + def prepare_inputs_for_generation( + self, decoder_input_ids, past, attention_mask, use_cache, doc_scores, n_docs=None, **kwargs + ) -> Dict: + assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" + + if len(past) == 1: + assert isinstance(past[0], tf.Tensor) + encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) + decoder_cached_states = None + else: + assert len(past) == 2 + # Note: encoder_outputs is never changed by Bart as a generator + encoder_outputs, decoder_cached_states = past + + if isinstance(encoder_outputs, tuple): + assert isinstance(encoder_outputs[0], tf.Tensor) + encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) + elif isinstance(encoder_outputs, tf.Tensor): + encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) + + assert ( + decoder_cached_states + ), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past" + # if past is defined cut decoder_input_ids to last token + decoder_input_ids = decoder_input_ids[:, -1:] + + assert isinstance( + encoder_outputs, TFBaseModelOutput + ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "doc_scores": doc_scores, + "context_attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "past_key_values": decoder_cached_states, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + "do_marginalize": True, + "n_docs": n_docs, + } + + @property + def retriever(self): + return self.rag.retriever + + @property + def generator(self): + return self.rag.generator + + @property + def question_encoder(self): + return self.rag.question_encoder + + @staticmethod + def _reorder_cache(past, beam_idx): + """Reorders cache for generation. 
BART-inspired but we need to take care of the extra dimension for docs"""
+
+        def tf_index_select(input_, dim, indices):
+            """
+            Input:
+                input_(tensor): input tensor dim(int): dimension indices(list): selected indices list
+            Output:
+                mimic of torch_tensor.index_select(dim, indices)
+
+            credit: https://stackoverflow.com/questions/58464790/is-there-an-equivalent-function-of-pytorch-named-index-select-in-tensorflow
+            """
+            shape = shape_list(input_)
+            if dim == -1:
+                dim = len(shape) - 1
+            shape[dim] = 1
+
+            tmp = []
+            for idx in indices:
+                begin = [0] * len(shape)
+                begin[dim] = idx
+                tmp.append(tf.slice(input_, begin, shape))
+            res = tf.concat(tmp, axis=dim)
+
+            return res
+
+        def _reorder_stacked(hidden_states, new_order=beam_idx):
+            n_docs = hidden_states.shape[0] // new_order.shape[0]
+            hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
+            hidden_states = tf_index_select(hidden_states, 0, new_order)
+            return tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
+
+        if len(past) == 1:
+            return past
+
+        past_key_values = past[1]
+
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
+
+        return (past[0], reordered_past)
+
+    def marginalize(self, seq_logits, doc_scores, n_docs=None):
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
+
+        # RAG-token marginalization
+        seq_logprobs = tf.nn.log_softmax(seq_logits, axis=-1)
+        seq_logprobs = tf.reshape(seq_logprobs, [seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.shape[-1]])
+        doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1)
+        doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1)
+        doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1)  # twice
+        log_prob_sum = seq_logprobs + doc_logprobs
+        return tf.reduce_logsumexp(log_prob_sum, axis=1)
+
+    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        encoder_outputs=None,
+        past_key_values=None,
+        doc_scores=None,
+        context_input_ids=None,
+        context_attention_mask=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        output_retrieved=None,
+        n_docs=None,
+        do_marginalize=None,
+        labels=None,
+        reduce_loss=None,
+        return_dict=None,
+        training=False,
+        **kwargs  # needs kwargs for generation
+    ):
+        r"""
+        do_marginalize (:obj:`bool`, `optional`):
+            If :obj:`True`, the logits are marginalized over all documents by making use of
+            ``tf.nn.log_softmax``.
+        labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the cross entropy classification loss according to the Rag-Token model formulation.
+            See https://arxiv.org/pdf/2005.11401.pdf Section 2.1 for details about the Rag-Token formulation. Indices
+            should be in ``[0, ..., config.vocab_size - 1]``.
+        reduce_loss (:obj:`bool`, `optional`):
+            Only relevant if ``labels`` is passed. If :obj:`True`, the NLL loss is reduced using the
+            ``tf.math.reduce_sum`` operation.
+        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
+            Legacy dictionary, which is required so that the model can use the ``generate()`` function.
+ + Returns: + + Example:: + + >>> from transformers import RagTokenizer, RagRetriever, TFRagTokenForGeneration + + >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq") + >>> retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True) + >>> # initialize with RagRetriever to do everything in one forward call + >>> model = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True) + + >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf") + >>> outputs = model(input_dict, output_retrieved=True) + + >>> # or use retriever separately + >>> # 1. Encode + >>> input_ids = input_dict["input_ids"] + >>> question_hidden_states = model.question_encoder(input_ids)[0] + >>> # 2. Retrieve + >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf") + >>> doc_scores = tf.squeeze(tf.matmul(tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True), axis=1) + >>> # 3. Forward to generator + >>> outputs = model(inputs=None, context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"]) + + >>> # or directly generate + >>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores) + >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True) + """ + + assert ( + "decoder_cached_states" not in kwargs + ), "Please use past_key_values to cache intermediate outputs" # from modeling_tf_bart.py + + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + doc_scores=doc_scores, + context_input_ids=context_input_ids, + context_attention_mask=context_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_retrieved=output_retrieved, + n_docs=n_docs, + do_marginalize=do_marginalize, + labels=labels, + reduce_loss=reduce_loss, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + + inputs["do_marginalize"] = inputs["do_marginalize"] if inputs["do_marginalize"] else self.config.do_marginalize + inputs["reduce_loss"] = inputs["reduce_loss"] if inputs["reduce_loss"] else self.config.reduce_loss + + if inputs["labels"] is not None: + if inputs["decoder_input_ids"] is None: + inputs["decoder_input_ids"] = inputs["labels"] + inputs["use_cache"] = False + + outputs = self.rag( + inputs["input_ids"], + attention_mask=inputs["attention_mask"], + encoder_outputs=inputs["encoder_outputs"], + decoder_input_ids=inputs["decoder_input_ids"], + decoder_attention_mask=inputs["decoder_attention_mask"], + context_input_ids=inputs["context_input_ids"], + context_attention_mask=inputs["context_attention_mask"], + doc_scores=inputs["doc_scores"], + past_key_values=inputs["past_key_values"], + use_cache=inputs["use_cache"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + output_retrieved=inputs["output_retrieved"], + n_docs=inputs["n_docs"], + 
training=inputs["training"], + ) + + loss = None + logits = outputs.logits + if inputs["labels"] is not None: + assert inputs["decoder_input_ids"] is not None + loss = self.get_nll( + outputs.logits, + outputs.doc_scores, + inputs["labels"], + reduce_loss=inputs["reduce_loss"], + epsilon=self.config.label_smoothing, + n_docs=inputs["n_docs"], + ) + + if inputs["do_marginalize"]: + logits = self.marginalize(logits, outputs.doc_scores, inputs["n_docs"]) + + return TFRetrievAugLMMarginOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + doc_scores=outputs.doc_scores, + context_input_ids=outputs.context_input_ids, + context_attention_mask=outputs.context_attention_mask, + retrieved_doc_embeds=outputs.retrieved_doc_embeds, + retrieved_doc_ids=outputs.retrieved_doc_ids, + question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state, + question_enc_hidden_states=outputs.question_enc_hidden_states, + question_enc_attentions=outputs.question_enc_attentions, + generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state, + generator_enc_hidden_states=outputs.generator_enc_hidden_states, + generator_enc_attentions=outputs.generator_enc_attentions, + generator_dec_hidden_states=outputs.generator_dec_hidden_states, + generator_dec_attentions=outputs.generator_dec_attentions, + ) + + def generate( + self, + input_ids: Optional[tf.Tensor] = None, + attention_mask: Optional[tf.Tensor] = None, + context_input_ids=None, + context_attention_mask=None, + doc_scores=None, + max_length=None, + min_length=None, + early_stopping=None, + use_cache=None, + num_beams=None, + bos_token_id=None, + pad_token_id=None, + eos_token_id=None, + length_penalty=None, + no_repeat_ngram_size=None, + bad_words_ids=None, + num_return_sequences=None, + decoder_start_token_id=None, + n_docs=None, + **kwargs + ): + """ + Implements TFRAG token decoding. + + Args: + input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then + :obj:`context_input_ids` has to be provided. + attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + context_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Input IDs post-processed from the retrieved documents and the question encoder :obj:`input_ids` by the + retriever. + + If the model has is not initialized with a ``retriever``, :obj:`context_input_ids` has to be provided + to the forward pass. :obj:`context_input_ids` are returned by + :meth:`~transformers.RagRetriever.__call__`. + context_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`): + Attention mask post-processed from the retrieved documents and the question encoder :obj:`input_ids` by + the retriever. + + If the model has is not initialized with a ``retriever``, :obj:`context_input_ids` has to be provided + to the forward pass. :obj:`context_input_ids` are returned by + :meth:`~transformers.RagRetriever.__call__`. 
+            doc_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.n_docs)`):
+                Score between each retrieved document embeddings (see :obj:`retrieved_doc_embeds`) and
+                :obj:`question_encoder_last_hidden_state`.
+
+                If the model is not initialized with a ``retriever``, :obj:`doc_scores` has to be provided
+                to the forward pass. :obj:`doc_scores` are returned by
+                :meth:`~transformers.RagRetriever.__call__`.
+            max_length (:obj:`int`, `optional`, defaults to 20):
+                The maximum length of the sequence to be generated.
+            min_length (:obj:`int`, `optional`, defaults to 10):
+                The minimum length of the sequence to be generated.
+            early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to stop the beam search when at least ``num_beams`` sentences are finished per batch.
+            use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
+                speed up decoding.
+            pad_token_id (:obj:`int`, `optional`):
+                The id of the `padding` token.
+            bos_token_id (:obj:`int`, `optional`):
+                The id of the `beginning-of-sequence` token.
+            eos_token_id (:obj:`int`, `optional`):
+                The id of the `end-of-sequence` token.
+            length_penalty (:obj:`float`, `optional`, defaults to 1.0):
+                Exponential penalty to the length. 1.0 means no penalty.
+
+                Set to values < 1.0 in order to encourage the model to generate shorter sequences, or to values > 1.0
+                in order to encourage the model to produce longer sequences.
+            no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
+                If set to int > 0, all ngrams of that size can only occur once.
+            bad_words_ids (:obj:`List[int]`, `optional`):
+                List of token ids that are not allowed to be generated. In order to get the tokens of the words that
+                should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
+            num_beams (:obj:`int`, `optional`, defaults to 1):
+                Number of beams for beam search. 1 means no beam search.
+            num_return_sequences (:obj:`int`, `optional`, defaults to 1):
+                The number of independently computed returned sequences for each element in the batch. Note that this
+                is not the value we pass to the ``generator``'s :func:`~transformers.PreTrainedModel.generate`
+                function, where we set ``num_return_sequences`` to :obj:`num_beams`.
+            decoder_start_token_id (:obj:`int`, `optional`):
+                If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
+            n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`):
+                Number of documents to retrieve and/or number of documents for which to generate an answer.
+
+        Return:
+            :obj:`tf.Tensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
+            sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
+            batches finished early due to the :obj:`eos_token_id`.
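+
+        Example (a minimal sketch, assuming a model initialized with a ``retriever``, a ``tokenizer`` and a
+        tokenized ``input_dict`` as in the ``call`` example above)::
+
+            >>> generated = model.generate(input_ids=input_dict["input_ids"], num_beams=4)
+            >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)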
+ """ + # set default parameters + n_docs = n_docs if n_docs is not None else self.config.n_docs + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + use_cache = use_cache if use_cache is not None else self.config.use_cache + num_beams = num_beams if num_beams is not None else self.config.num_beams + bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + no_repeat_ngram_size = ( + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size + ) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + num_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + decoder_start_token_id = ( + decoder_start_token_id + if decoder_start_token_id is not None + else self.config.generator.decoder_start_token_id + ) + + # retrieve docs + if self.retriever is not None and context_input_ids is None: + question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0] + out = self.retriever( + input_ids, + question_hidden_states.numpy().astype(np.float32), + prefix=self.generator.config.prefix, + n_docs=n_docs, + return_tensors="tf", + ) + context_input_ids, context_attention_mask, retrieved_doc_embeds = ( + out["context_input_ids"], + out["context_attention_mask"], + out["retrieved_doc_embeds"], + ) + + context_input_ids = tf.cast(context_input_ids, tf.int32) + context_attention_mask = tf.cast(context_attention_mask, tf.int32) + retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32) + + # compute doc_scores + doc_scores = tf.matmul( + tf.expand_dims(question_hidden_states, axis=1), retrieved_doc_embeds, transpose_b=True + ) + doc_scores = tf.squeeze(doc_scores, axis=1) + + assert ( + context_input_ids.shape[0] % n_docs + ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}." 
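+        # `context_input_ids` stacks the `n_docs` retrieved contexts for each question along the batch
+        # axis, i.e. it has shape (batch_size * n_docs, max_combined_length) with rows ordered as
+        # [q0_doc0, ..., q0_doc{n_docs-1}, q1_doc0, ...] -- hence the divisibility check above.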
+
+        batch_size = context_input_ids.shape[0] // n_docs
+
+        encoder = self.rag.generator.get_encoder()
+        encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
+
+        decoder_input_ids = tf.fill(
+            (batch_size * num_beams, 1),
+            tf.cast(decoder_start_token_id, tf.int32),
+        )
+        last_hidden_state = encoder_outputs["last_hidden_state"]
+
+        def extend_enc_output(tensor, num_beams=None):
+            """
+            Broadcast the tensor over `num_beams` replicas, in the correct order.
+
+            Input: tensor of shape (batch_size * n_docs, d)
+            Output: tensor of shape (batch_size * num_beams * n_docs, d)
+            """
+
+            # expand batch_size & num_beam dimensions
+            d_shape_list = tensor.shape[1:]
+
+            # split n_docs dimensions
+            new_shape = (batch_size, 1, n_docs) + d_shape_list
+            tensor = tf.reshape(tensor, new_shape)
+
+            # repeat same last hidden states over `num_beams` dimension
+            new_shape = (batch_size, num_beams, n_docs) + d_shape_list
+            tensor = tf.broadcast_to(tensor, new_shape)
+
+            # merge `batch_size`, `num_beams`, `num_docs` dims again
+            new_shape = (batch_size * num_beams * n_docs,) + d_shape_list
+            return tf.reshape(tensor, new_shape)
+
+        # correctly extend last_hidden_state and attention mask
+        context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
+        encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams)
+
+        doc_scores = tf.repeat(doc_scores, num_beams, axis=0)
+
+        # define start_len & additional parameters
+        cur_len = 1
+        vocab_size = self.config.generator.vocab_size
+        kwargs["doc_scores"] = doc_scores
+        kwargs["encoder_outputs"] = encoder_outputs
+        kwargs["n_docs"] = n_docs
+
+        # not needed. TODO(PVP): change after generate refactor
+        do_sample = False
+        temperature = self.config.temperature
+        top_k = self.config.top_k
+        top_p = self.config.top_p
+        repetition_penalty = self.config.repetition_penalty
+
+        if num_beams > 1:
+            return self._generate_beam_search(
+                decoder_input_ids,
+                cur_len=cur_len,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=do_sample,
+                early_stopping=early_stopping,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bad_words_ids=bad_words_ids,
+                pad_token_id=pad_token_id,
+                eos_token_id=eos_token_id,
+                batch_size=batch_size,
+                num_return_sequences=num_return_sequences,
+                length_penalty=length_penalty,
+                num_beams=num_beams,
+                vocab_size=vocab_size,
+                attention_mask=context_attention_mask,
+                use_cache=use_cache,
+                forced_bos_token_id=None,
+                forced_eos_token_id=None,
+                **kwargs,  # encoder_outputs is here as in PyTorch's version
+            )
+        else:
+            return self._generate_no_beam_search(
+                decoder_input_ids,
+                cur_len=cur_len,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bad_words_ids=bad_words_ids,
+                pad_token_id=pad_token_id,
+                eos_token_id=eos_token_id,
+                batch_size=batch_size,
+                vocab_size=vocab_size,
+                attention_mask=context_attention_mask,
+                use_cache=use_cache,
+                forced_bos_token_id=None,
+                forced_eos_token_id=None,
+                **kwargs,  # encoder_outputs is here as in PyTorch's version
+            )
+
+    def get_input_embeddings(self):
+        return self.rag.generator.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.rag.generator.get_output_embeddings()
+
+    # Adapted from tf_t5's & tf_bart's _shift_right
+    def shift_tokens_right(self, input_ids, start_token_id=None):
+        """Shift input ids one token to the right, and pad with start_token_id"""
+
+        if start_token_id is None:
+            start_token_id = self.generator.config.decoder_start_token_id
+            assert (
+                start_token_id is not None
+            ), "self.generator.config.decoder_start_token_id has to be defined. In RAG we commonly use BART as the generator; see the BART docs for more information."
+
+        pad_token_id = self.generator.config.pad_token_id
+        assert pad_token_id is not None, "self.generator.config.pad_token_id has to be defined."
+
+        shifted_input_ids = tf.cast(input_ids, tf.int32)
+        shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+        start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), start_token_id)
+        shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
+
+        # replace possible -100 values in labels by `pad_token_id`
+        shifted_input_ids = tf.where(
+            shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
+        )
+
+        # verify that `labels` has only positive values and -100
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+
+        # make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):
+            shifted_input_ids = tf.identity(shifted_input_ids)
+
+        return shifted_input_ids
+
+    # nll stands for 'negative log likelihood'
+    def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None):
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
+        # shift tokens left (from the original PyTorch version)
+        target = tf.concat([target[:, 1:], tf.fill([target.shape[0], 1], self.config.generator.pad_token_id)], axis=1)
+        rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
+        loss = self.compute_loss(target, rag_logprobs, from_logits=True, reduce_loss=reduce_loss)
+
+        return loss
+
+    # Adapted from modeling_tf_bart, with smooth_loss added to match the PyTorch version
+    def compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
+        """CrossEntropyLoss that ignores pad tokens"""
+        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True,
+            reduction=tf.keras.losses.Reduction.SUM,
+        )
+
+        if from_logits is False:  # convert to logits
+            eps = 1e-9
+            y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
+            y_pred = tf.math.log(y_pred)
+
+        logits = y_pred
+        melted_labels = tf.reshape(labels, (-1,))
+        active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)
+
+        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
+        labels = tf.boolean_mask(melted_labels, active_loss)
+        nll_loss = loss_fn(labels, reduced_logits)
+
+        smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
+        smooth_loss = tf.reduce_sum(smooth_loss)  # sum and squeeze like torch
+        eps_i = smooth_epsilon / reduced_logits.shape[-1]
+
+        loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
+
+        return loss
+
+
+@add_start_docstrings_to_model_forward(
+    """
+    A TF RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
+    """,
+    RAG_START_DOCSTRING,
+)
+class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
+
+    load_weight_prefix = "tf_rag_sequence_for_generation_1/rag"
+
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        question_encoder: Optional[TFPreTrainedModel] = None,
+        generator: Optional[TFPreTrainedModel] = None,
+        retriever: Optional[RagRetriever] = None,
+        **kwargs,
+    ):
+        assert config is not None or (
+            question_encoder is not None and generator is not None
+        ), "Either a configuration or an encoder and a generator has to be provided."
+
+        if config is None:
+            config = RagConfig.from_question_encoder_generator_configs(
+                question_encoder.config, generator.config, **kwargs
+            )
+
+        super().__init__(config)
+
+        # instantiate model
+        self.rag = TFRagModel(
+            config=config,
+            question_encoder=question_encoder,
+            generator=generator,
+            retriever=retriever,
+            load_weight_prefix=self.load_weight_prefix,
+            name="rag",
+        )
+
+    def set_retriever(self, retriever: RagRetriever):
+        self.rag.retriever = retriever
+
+    @property
+    def retriever(self):
+        return self.rag.retriever
+
+    @property
+    def generator(self):
+        return self.rag.generator
+
+    @property
+    def question_encoder(self):
+        return self.rag.question_encoder
+
+    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        encoder_outputs=None,
+        past_key_values=None,
+        doc_scores=None,
+        context_input_ids=None,
+        context_attention_mask=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        output_retrieved=None,
+        n_docs=None,
+        exclude_bos_score=None,
+        labels=None,
+        reduce_loss=None,
+        return_dict=None,
+        training=False,
+        **kwargs  # needs kwargs for generation
+    ):
+        r"""
+        exclude_bos_score (:obj:`bool`, `optional`):
+            Only relevant if ``labels`` is passed. If :obj:`True`, the score of the BOS token is disregarded when
+            computing the loss.
+        labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the cross entropy classification loss according to the Rag-Sequence model
+            formulation. See https://arxiv.org/pdf/2005.11401.pdf Section 2.1 for details about the Rag-Sequence
+            formulation. Indices should be in ``[0, ..., config.vocab_size - 1]``.
+        reduce_loss (:obj:`bool`, `optional`):
+            Only relevant if ``labels`` is passed. If :obj:`True`, the NLL loss is reduced using the ``tf.Tensor.sum``
+            operation.
+        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
+            Legacy dictionary, which is required so that the model can use the `generate()` function.
+
+        Returns:
+
+        Example::
+
+            >>> from transformers import RagTokenizer, RagRetriever, TFRagSequenceForGeneration
+
+            >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+            >>> retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)
+            >>> # initialize with RagRetriever to do everything in one forward call
+            >>> model = TFRagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever, from_pt=True)
+
+            >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf")
+            >>> outputs = model(input_dict, output_retrieved=True)
+
+            >>> # or use retriever separately
+            >>> # 1. Encode
+            >>> input_ids = input_dict["input_ids"]
+            >>> question_hidden_states = model.question_encoder(input_ids)[0]
+            >>> # 2. Retrieve
+            >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")
+            >>> doc_scores = tf.squeeze(tf.matmul(tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True), axis=1)
+            >>> # 3. Forward to generator
+            >>> outputs = model(inputs=None, context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
+
+            >>> # or directly generate
+            >>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
+            >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
+        """
+
+        assert (
+            "decoder_cached_states" not in kwargs
+        ), "Please use past_key_values to cache intermediate outputs"  # from modeling_tf_bart.py
+
+        inputs = input_processing(
+            func=self.call,
+            config=self.config,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            doc_scores=doc_scores,
+            context_input_ids=context_input_ids,
+            context_attention_mask=context_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_retrieved=output_retrieved,
+            n_docs=n_docs,
+            exclude_bos_score=exclude_bos_score,
+            labels=labels,
+            reduce_loss=reduce_loss,
+            training=training,
+            return_dict=return_dict,
+            kwargs_call=kwargs,
+        )
+
+        inputs["exclude_bos_score"] = (
+            inputs["exclude_bos_score"] if inputs["exclude_bos_score"] else self.config.exclude_bos_score
+        )
+        inputs["reduce_loss"] = inputs["reduce_loss"] if inputs["reduce_loss"] else self.config.reduce_loss
+
+        if inputs["labels"] is not None:
+            if inputs["decoder_input_ids"] is None:
+                inputs["decoder_input_ids"] = inputs["labels"]
+            inputs["use_cache"] = False
+
+        outputs = self.rag(
+            inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            encoder_outputs=inputs["encoder_outputs"],
+            decoder_input_ids=inputs["decoder_input_ids"],
+            decoder_attention_mask=inputs["decoder_attention_mask"],
+            context_input_ids=inputs["context_input_ids"],
+            context_attention_mask=inputs["context_attention_mask"],
+            doc_scores=inputs["doc_scores"],
+            past_key_values=inputs["past_key_values"],
+            use_cache=inputs["use_cache"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            output_retrieved=inputs["output_retrieved"],
+            n_docs=inputs["n_docs"],
+            training=inputs["training"],
+        )
+
+        loss = None
+        if inputs["labels"] is not None:
+            loss = self.get_nll(
+                outputs.logits,
+                outputs.doc_scores,
+                inputs["labels"],
+                reduce_loss=inputs["reduce_loss"],
+                epsilon=self.config.label_smoothing,
+                n_docs=inputs["n_docs"],
+            )
+
+        return TFRetrievAugLMMarginOutput(
+            loss=loss,
+            logits=outputs.logits,
+            doc_scores=outputs.doc_scores,
+            past_key_values=outputs.past_key_values,
+            context_input_ids=outputs.context_input_ids,
+            context_attention_mask=outputs.context_attention_mask,
+            retrieved_doc_embeds=outputs.retrieved_doc_embeds,
+            retrieved_doc_ids=outputs.retrieved_doc_ids,
+            question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
question_enc_hidden_states=outputs.question_enc_hidden_states, + question_enc_attentions=outputs.question_enc_attentions, + generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state, + generator_enc_hidden_states=outputs.generator_enc_hidden_states, + generator_enc_attentions=outputs.generator_enc_attentions, + generator_dec_hidden_states=outputs.generator_dec_hidden_states, + generator_dec_attentions=outputs.generator_dec_attentions, + ) + + def get_nll( + self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None + ): + # shift tokens left + target = tf.concat([target[:, 1:], tf.fill([target.shape[0], 1], self.config.generator.pad_token_id)], axis=1) + + # bos_token_id is None for T5 + bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id + n_docs = n_docs if n_docs is not None else self.config.n_docs + equal_bos_token_id_all = tf.reduce_all(tf.equal(target[:, 0], bos_token_id)) + use_bos = bos_token_id is not None and equal_bos_token_id_all + + def _mask_pads(ll, smooth_obj): + pad_mask = tf.equal(target, self.config.generator.pad_token_id) + if tf.reduce_any(pad_mask): + ll = tf.where(pad_mask, 0.0, ll) + smooth_obj = tf.where(pad_mask, 0.0, smooth_obj) + return tf.squeeze(ll, axis=-1), tf.squeeze(smooth_obj, axis=-1) + + # seq_logits.shape = (batch*n_docs, tgt_len , vocabs) + seq_logprobs = tf.nn.log_softmax(seq_logits, axis=-1) + seq_logprobs = tf.reshape( + seq_logprobs, (seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.shape[-1]) + ) # (batch_size, n_docs, tgt_len, vocabs) + doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1) + doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1) + doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1) # done twice to get 4-D + + # RAG-sequence marginalization + first_token_scores = seq_logprobs[:, :, :1, :] + second_token_scores = seq_logprobs[:, :, 1:2, :] + remainder = seq_logprobs[:, :, 2:, :] + rag_logprobs = tf.concat([first_token_scores, second_token_scores + doc_logprobs, remainder], axis=2) + + # calculate loss + target = tf.expand_dims(target, axis=1) # n_docs dimension + target = tf.expand_dims(target, axis=-1) # logits dimension + target = tf.repeat(target, n_docs, axis=1) + assert len(target.shape) == len(rag_logprobs.shape) + + # last-axis gathering only - use 2D-reshape-trick for Torch's style nD gathering + def torch_gather(param, id_tensor): + # 2d-gather torch equivalent: https://stackoverflow.com/questions/52129909/tensorflow-equivalent-of-torch-gather + def gather2d(target, id_tensor): + idx = tf.stack([tf.range(tf.shape(id_tensor)[0]), id_tensor[:, 0]], axis=-1) + result = tf.gather_nd(target, idx) + return tf.expand_dims(result, axis=-1) + + target = tf.reshape(param, (-1, param.shape[-1])) # reshape 2D + target_shape = id_tensor.shape + + id_tensor = tf.reshape(id_tensor, (-1, 1)) # also 2D-index + result = gather2d(target, id_tensor) + return tf.reshape(result, target_shape) + + ll = torch_gather(rag_logprobs, id_tensor=target) + smooth_obj = tf.reduce_sum(rag_logprobs, axis=-1, keepdims=True) # total sum of all (normalised) logits + + ll, smooth_obj = _mask_pads(ll, smooth_obj) + + # sum over tokens, exclude bos while scoring + if exclude_bos_score and use_bos: + ll = tf.reduce_sum(ll[:, :, 1:], axis=2) + else: + ll = tf.reduce_sum(ll, axis=2) + + smooth_obj = tf.reduce_sum(smooth_obj, axis=2) + ll = tf.math.reduce_logsumexp(ll, axis=1) # logsumexp over docs + smooth_obj = tf.math.reduce_logsumexp(smooth_obj, axis=1) + + nll_loss = -ll 
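+        # RAG-sequence NLL: with doc posterior p(z|x) and generator likelihood p(y|x, z),
+        #   log p(y|x) = logsumexp_z [log p(z|x) + sum_t log p(y_t|x, z, y_<t)]
+        # which is exactly the logsumexp over the document axis computed above; `smooth_obj`
+        # is the analogous quantity for the label-smoothing term.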
+        smooth_loss = -smooth_obj
+
+        if reduce_loss:
+            nll_loss = tf.reduce_sum(nll_loss)
+            smooth_loss = tf.reduce_sum(smooth_loss)
+
+        eps_i = epsilon / rag_logprobs.shape[-1]
+        loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
+        return loss
+
+    def generate(
+        self,
+        input_ids: Optional[tf.Tensor] = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        context_input_ids=None,
+        context_attention_mask=None,
+        doc_scores=None,
+        do_deduplication=None,  # defaults to True
+        num_return_sequences=None,  # defaults to 1
+        num_beams=None,  # defaults to 1
+        n_docs=None,
+        **model_kwargs
+    ):
+        """
+        Implements RAG sequence "thorough" decoding. Read the :meth:`~transformers.PreTrainedModel.generate`
+        documentation for more information on how to set other generate input parameters.
+
+        Args:
+            input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+                The sequence used as a prompt for the generation. If :obj:`input_ids` is not passed, then
+                :obj:`context_input_ids` has to be provided.
+            attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+                Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                `What are attention masks? <../glossary.html#attention-mask>`__
+            context_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
+                Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
+                retriever.
+            context_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
+                Attention mask post-processed from the retrieved documents and the question encoder :obj:`input_ids` by
+                the retriever. If the model is not initialized with a ``retriever`` or ``input_ids`` is not given,
+                :obj:`context_input_ids` and :obj:`context_attention_mask` have to be provided to the forward pass.
+                They are returned by :meth:`~transformers.RagRetriever.__call__`.
+            doc_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.n_docs)`):
+                Score between each retrieved document embeddings (see :obj:`retrieved_doc_embeds`) and
+                :obj:`question_encoder_last_hidden_state`. If the model is not initialized with a ``retriever`` or
+                ``input_ids`` is not given, :obj:`doc_scores` has to be provided to the forward pass. :obj:`doc_scores`
+                are returned by :meth:`~transformers.RagRetriever.__call__`.
+            do_deduplication (:obj:`bool`, `optional`):
+                Whether or not to deduplicate the generations from different context documents for a given input. Has
+                to be set to :obj:`False` if used while training with a distributed backend.
+            num_return_sequences (:obj:`int`, `optional`, defaults to 1):
+                The number of independently computed returned sequences for each element in the batch. Note that this
+                is not the value we pass to the ``generator``'s :func:`~transformers.PreTrainedModel.generate`
+                function, where we set ``num_return_sequences`` to :obj:`num_beams`.
+            num_beams (:obj:`int`, `optional`, defaults to 1):
+                Number of beams for beam search. 1 means no beam search.
+            n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`):
+                Number of documents to retrieve and/or number of documents for which to generate an answer.
+            kwargs:
+                Additional kwargs will be passed to :meth:`~transformers.PreTrainedModel.generate`.
+
+        Return:
+            :obj:`tf.Tensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
+            sequences. The second dimension (sequence length) is either equal to :obj:`max_length` or shorter if all
+            batches finished early due to the :obj:`eos_token_id`.
+        """
+
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
+        do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication
+        num_doc_return_sequences = (
+            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
+        )
+        num_beams = num_beams if num_beams is not None else self.config.num_beams
+
+        assert (
+            input_ids is not None or context_input_ids is not None
+        ), "At least one of input_ids or context_input_ids must be given"
+
+        if self.retriever is not None and context_input_ids is None:
+            question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
+            context_input_ids = self.retriever(
+                input_ids,
+                question_hidden_states.numpy(),
+                prefix=self.generator.config.prefix,
+                n_docs=n_docs,
+                return_tensors="tf",
+            )["context_input_ids"]
+
+        hypos = []
+        model_kwargs["num_beams"] = num_beams
+        model_kwargs["num_return_sequences"] = num_beams  # set here so it is not confused with num_doc_return_sequences
+        model_kwargs["attention_mask"] = None
+
+        batch_size = input_ids.shape[0] if input_ids is not None else context_input_ids.shape[0] // n_docs
+
+        for index in range(batch_size):
+            # first, generate beams from documents:
+            generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs]  # (n_docs, max_len)
+
+            output_sequences = self.generator.generate(
+                generator_input_ids,
+                **model_kwargs,
+            )  # n_docs * n_beam, tgt_len
+            if do_deduplication:
+                # do_deduplication -- for TF, this works in eager mode only!
+                output_sequences = tf.stack(list({str(k.numpy().tolist()): k for k in output_sequences}.values()))
+
+            num_candidates = output_sequences.shape[
+                0
+            ]  # after deduplication, this number can be less than n_docs * n_beam
+
+            # then, run model forwards to get nll scores:
+            if input_ids is not None:
+                new_input_ids = tf.tile(input_ids[index : index + 1], (num_candidates, 1))
+                outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
+            else:  # input_ids is None, need context_input_ids/mask and doc_scores
+                assert (
+                    context_attention_mask is not None
+                ), "Make sure that `context_attention_mask` is passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+                assert (
+                    doc_scores is not None
+                ), "Make sure that `doc_scores` is passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
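+
+                # re-score each deduplicated candidate against all n_docs contexts: tile the context
+                # ids/mask and doc scores `num_candidates` times so that a single forward pass yields
+                # one NLL per candidate, from which top_k below picks the best hypotheses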
+ + individual_input_ids = tf.tile( + generator_input_ids, (num_candidates, 1) + ) # (num_candidates*n_docs, max_len) + + individual_attention_mask = context_attention_mask[index * n_docs : (index + 1) * n_docs] + individual_attention_mask = tf.tile(individual_attention_mask, (num_candidates, 1)) + + individual_doc_scores = doc_scores[index : (index + 1), :] # doc_scores.shape = [batch, n_docs] + individual_doc_scores = tf.tile(individual_doc_scores, (num_candidates, 1)) # [num_candidates, n_docs] + + outputs = self( + input_ids=None, + context_input_ids=individual_input_ids, + context_attention_mask=individual_attention_mask, + doc_scores=individual_doc_scores, + labels=output_sequences, + exclude_bos_score=True, + ) + + top_cand_inds = tf.math.top_k((-outputs["loss"]), k=num_doc_return_sequences)[1] + + # add hypothesis + hypos.append(tf.gather(output_sequences, top_cand_inds)) + + return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id) + + @staticmethod + def _cat_and_pad(tensors, pad_token_id): + # used by generate(): tensors is a (batched) list of (candidates, len); len is varied across batch + + # Initialize padded tensor with shape ( all_candidates , max_candidate_length ), + # where all_candidates counted from all inputs + new_shape = sum([t.shape[0] for t in tensors]), max([t.shape[1] for t in tensors]) + output = tf.fill(new_shape, pad_token_id) + + # Normal tensor doesn't support slice assignment, so we need tf.Variable + output = tf.Variable(output) + + # Assign, and then convert back to tensor + ind = 0 + for t in tensors: + output[ind : ind + t.shape[0], : t.shape[1]].assign(t) + ind += t.shape[0] + + output = tf.convert_to_tensor(output) + return tf.cast(output, tensors[0][0][0].dtype) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 838a9293f..e6080a864 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1332,6 +1332,25 @@ class TFPegasusModel: requires_tf(self) +class TFRagModel: + def __init__(self, *args, **kwargs): + requires_tf(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tf(self) + + +class TFRagSequenceForGeneration: + def __init__(self, *args, **kwargs): + requires_tf(self) + + +class TFRagTokenForGeneration: + def __init__(self, *args, **kwargs): + requires_tf(self) + + TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_tf_rag.py b/tests/test_modeling_tf_rag.py new file mode 100644 index 000000000..ec96aee8f --- /dev/null +++ b/tests/test_modeling_tf_rag.py @@ -0,0 +1,1102 @@ +import json +import os +import shutil +import tempfile +import unittest +from unittest.mock import patch + +import numpy as np + +from transformers import BartTokenizer +from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_tf_available +from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES +from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer +from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES +from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow + + +if is_tf_available() and is_datasets_available() and is_faiss_available(): + import tensorflow as tf + from datasets import Dataset + import faiss + + from transformers import ( + AutoConfig, + RagConfig, + RagRetriever, + 
RagTokenizer,
+        TFAutoModel,
+        TFAutoModelForSeq2SeqLM,
+        TFRagModel,
+        TFRagSequenceForGeneration,
+        TFRagTokenForGeneration,
+    )
+
+    from transformers.modeling_tf_outputs import TFBaseModelOutput
+
+from .test_modeling_tf_bart import TFBartModelTester
+from .test_modeling_tf_dpr import TFDPRModelTester
+
+
+TOLERANCE = 1e-3
+
+
+def require_retrieval(test_case):
+    """
+    Decorator marking a test that requires a set of dependencies necessary to perform retrieval with
+    :class:`~transformers.RagRetriever`.
+
+    These tests are skipped when respective libraries are not installed.
+    """
+    if not (is_tf_available() and is_datasets_available() and is_faiss_available()):
+        test_case = unittest.skip("test requires tensorflow, datasets and faiss")(test_case)
+    return test_case
+
+
+@require_tf
+@require_retrieval
+@require_sentencepiece
+class TFRagTestMixin:
+
+    all_model_classes = (
+        (TFRagModel, TFRagTokenForGeneration, TFRagSequenceForGeneration)
+        if is_tf_available() and is_datasets_available() and is_faiss_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (TFRagTokenForGeneration, TFRagSequenceForGeneration)
+        if is_tf_available() and is_datasets_available() and is_faiss_available()
+        else ()
+    )
+
+    retrieval_vector_size = 32
+    n_docs = 3
+    max_combined_length = 16
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        # DPR tok
+        vocab_tokens = [
+            "[UNK]",
+            "[CLS]",
+            "[SEP]",
+            "[PAD]",
+            "[MASK]",
+            "want",
+            "##want",
+            "##ed",
+            "wa",
+            "un",
+            "runn",
+            "##ing",
+            ",",
+            "low",
+            "lowest",
+        ]
+        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
+        os.makedirs(dpr_tokenizer_path, exist_ok=True)
+        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
+        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+        # BART tok
+        vocab = [
+            "l",
+            "o",
+            "w",
+            "e",
+            "r",
+            "s",
+            "t",
+            "i",
+            "d",
+            "n",
+            "\u0120",
+            "\u0120l",
+            "\u0120n",
+            "\u0120lo",
+            "\u0120low",
+            "er",
+            "\u0120lowest",
+            "\u0120newer",
+            "\u0120wider",
+            "<unk>",
+        ]
+        vocab_tokens = dict(zip(vocab, range(len(vocab))))
+        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
+        self.special_tokens_map = {"unk_token": "<unk>"}
+
+        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
+        os.makedirs(bart_tokenizer_path, exist_ok=True)
+        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
+        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
+        with open(self.vocab_file, "w", encoding="utf-8") as fp:
+            fp.write(json.dumps(vocab_tokens) + "\n")
+        with open(self.merges_file, "w", encoding="utf-8") as fp:
+            fp.write("\n".join(merges))
+
+    @cached_property
+    def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
+        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
+
+    @cached_property
+    def bart_tokenizer(self) -> BartTokenizer:
+        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    def get_retriever(self, config):
+        dataset = Dataset.from_dict(
+            {
+                "id": ["0", "1", "3"],
+                "text": ["foo", "bar", "qux"],
+                "title": ["Foo", "Bar", "Qux"],
+                "embeddings": [
+                    np.ones(self.retrieval_vector_size),
+                    2 * np.ones(self.retrieval_vector_size),
+                    3 * np.ones(self.retrieval_vector_size),
+                ],
+            }
+        )
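+        # index the toy 3-document dataset with an exact (flat) inner-product FAISS index, matching
+        # the dot-product scoring used between question and document embeddings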
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT) + tokenizer = self.bart_tokenizer + with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset: + mock_load_dataset.return_value = dataset + retriever = RagRetriever( + config, + question_encoder_tokenizer=self.dpr_tokenizer, + generator_tokenizer=tokenizer, + ) + return retriever + + def check_model_with_retriever( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) + + for model_class in self.all_model_classes: + model = model_class(config, retriever=self.get_retriever(config)) + + self.assertTrue(model.config.is_encoder_decoder) + + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + # logits + self.assertEqual( + outputs.logits.shape, + (self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size), + ) + # generator encoder last hidden states + self.assertEqual( + outputs.generator_enc_last_hidden_state.shape, + (self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size), + ) + # doc scores + self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs)) + + def check_model_generate_from_context_input_ids( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) + + retriever = self.get_retriever(config) + + for i, model_class in enumerate(self.all_generative_model_classes): + model = model_class(config) + self.assertTrue(model.config.is_encoder_decoder) + + question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0] + + out = retriever( + input_ids, + question_hidden_states.numpy(), + prefix=config.generator.prefix, + return_tensors="tf", + ) + + context_input_ids, context_attention_mask, retrieved_doc_embeds = ( + out["context_input_ids"], + out["context_attention_mask"], + out["retrieved_doc_embeds"], + ) + retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32) + + # compute doc_scores + doc_scores = tf.squeeze( + tf.matmul(tf.expand_dims(question_hidden_states, axis=[1]), retrieved_doc_embeds, transpose_b=True), + axis=[1], + ) + + outputs = model.generate( + context_input_ids=context_input_ids, + context_attention_mask=context_attention_mask, + doc_scores=doc_scores, + ) + + self.assertIsNotNone(outputs) + + def check_model_generate( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) + + for model_class in self.all_generative_model_classes: + model = model_class(config, retriever=self.get_retriever(config)) + + self.assertTrue(model.config.is_encoder_decoder) + + input_ids = tf.cast(input_ids, tf.int32) + outputs = model.generate( + input_ids=input_ids, + num_beams=2, + num_return_sequences=2, + decoder_start_token_id=config.generator.eos_token_id, + ) + + self.assertIsNotNone(outputs) + + def check_model_without_retriever( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) 
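+
+        # run retrieval manually and feed its outputs to the model, exercising the code path
+        # where no RagRetriever is attached to the model itself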
+ + retriever = self.get_retriever(config) + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertTrue(model.config.is_encoder_decoder) + + question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0] + + out = retriever( + input_ids, + question_hidden_states.numpy(), + prefix=config.generator.prefix, + return_tensors="tf", + ) + + context_input_ids, context_attention_mask, retrieved_doc_embeds = ( + out["context_input_ids"], + out["context_attention_mask"], + out["retrieved_doc_embeds"], + ) + + retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32) + + # compute doc_scores + doc_scores = tf.squeeze( + tf.matmul(tf.expand_dims(question_hidden_states, axis=[1]), retrieved_doc_embeds, transpose_b=True), + axis=[1], + ) + + outputs = model( + input_ids=None, + context_input_ids=context_input_ids, + context_attention_mask=context_attention_mask, + doc_scores=doc_scores, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + # logits + self.assertEqual( + outputs.logits.shape, + (self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size), + ) + + # generator encoder last hidden states + self.assertEqual( + outputs.generator_enc_last_hidden_state.shape, + (self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size), + ) + # doc scores + self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs)) + + def check_model_custom_n_docs( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, n_docs, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) + + retriever = self.get_retriever(config) + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertTrue(model.config.is_encoder_decoder) + + question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0] + + out = retriever( + input_ids, + question_hidden_states.numpy(), + prefix=config.generator.prefix, + return_tensors="tf", + n_docs=n_docs, + ) + + context_input_ids, context_attention_mask, retrieved_doc_embeds = ( + out["context_input_ids"], + out["context_attention_mask"], + out["retrieved_doc_embeds"], + ) + + retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32) + + # compute doc_scores + doc_scores = tf.squeeze( + tf.matmul(tf.expand_dims(question_hidden_states, axis=[1]), retrieved_doc_embeds, transpose_b=True), + axis=[1], + ) + + outputs = model( + input_ids=None, + context_input_ids=context_input_ids, + context_attention_mask=context_attention_mask, + doc_scores=doc_scores, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + n_docs=n_docs, + ) + + # logits + self.assertEqual( + outputs.logits.shape, + (n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size), + ) + # generator encoder last hidden states + self.assertEqual( + outputs.generator_enc_last_hidden_state.shape, + (n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size), + ) + # doc scores + self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], n_docs)) + + def check_model_with_mismatch_n_docs_value( + self, + config, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + retriever_n_docs, + generator_n_docs, + **kwargs + ): + self.assertIsNotNone(config.question_encoder) + 
self.assertIsNotNone(config.generator) + + retriever = self.get_retriever(config) + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertTrue(model.config.is_encoder_decoder) + + question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0] + + out = retriever( + input_ids, + question_hidden_states.numpy(), + prefix=config.generator.prefix, + return_tensors="tf", + n_docs=retriever_n_docs, + ) + + context_input_ids, context_attention_mask, retrieved_doc_embeds = ( + out["context_input_ids"], + out["context_attention_mask"], + out["retrieved_doc_embeds"], + ) + + retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32) + + # compute doc_scores + doc_scores = tf.squeeze( + tf.matmul(tf.expand_dims(question_hidden_states, axis=[1]), retrieved_doc_embeds, transpose_b=True), + axis=[1], + ) + + self.assertRaises( + AssertionError, + model.__call__, + input_ids=None, + context_input_ids=context_input_ids, + context_attention_mask=context_attention_mask, + doc_scores=doc_scores, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + n_docs=generator_n_docs, + ) + + def check_model_with_encoder_outputs( + self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs + ): + self.assertIsNotNone(config.question_encoder) + self.assertIsNotNone(config.generator) + + for model_class in self.all_model_classes: + model = model_class(config, retriever=self.get_retriever(config)) + + self.assertTrue(model.config.is_encoder_decoder) + + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + encoder_outputs = TFBaseModelOutput(outputs.generator_enc_last_hidden_state) + + # run only generator + outputs = model( + input_ids=None, + encoder_outputs=encoder_outputs, + doc_scores=outputs.doc_scores, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + # logits + self.assertEqual( + outputs.logits.shape, + (self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size), + ) + # generator encoder last hidden states + self.assertEqual( + outputs.generator_enc_last_hidden_state.shape, + (self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size), + ) + # doc scores + self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs)) + + def test_model_with_retriever(self): + inputs_dict = self.config_and_inputs + self.check_model_with_retriever(**inputs_dict) + + def test_model_without_retriever(self): + inputs_dict = self.config_and_inputs + self.check_model_without_retriever(**inputs_dict) + + def test_model_generate_from_context_input_ids(self): + inputs_dict = self.config_and_inputs + self.check_model_generate_from_context_input_ids(**inputs_dict) + + def test_model_with_encoder_outputs(self): + inputs_dict = self.config_and_inputs + self.check_model_with_encoder_outputs(**inputs_dict) + + def test_model_generate(self): + inputs_dict = self.config_and_inputs + self.check_model_generate(**inputs_dict) + + def test_model_with_custom_n_docs(self): + inputs_dict = self.config_and_inputs + inputs_dict["n_docs"] = 1 + self.check_model_custom_n_docs(**inputs_dict) + + def test_model_with_mismatch_n_docs_value(self): + inputs_dict = self.config_and_inputs + inputs_dict["retriever_n_docs"] = 3 + inputs_dict["generator_n_docs"] = 2 + 
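+        # retrieving 3 documents while telling the generator to expect 2 must raise an AssertionError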
self.check_model_with_mismatch_n_docs_value(**inputs_dict) + + +@require_tf +@require_retrieval +class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase): + @cached_property + def config_and_inputs(self): + question_encoder_tester = TFDPRModelTester(self) + dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs() + generator_tester = TFBartModelTester(self) + bart_config_and_inputs = generator_tester.prepare_config_and_inputs_for_common() + + (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs + (generator_config, bart_inputs_dict) = bart_config_and_inputs + decoder_input_ids, decoder_attention_mask = bart_inputs_dict["input_ids"], bart_inputs_dict["attention_mask"] + + config = RagConfig.from_question_encoder_generator_configs( + question_encoder_config, + generator_config, + n_docs=self.n_docs, + retrieval_vector_size=self.retrieval_vector_size, + max_combined_length=self.max_combined_length, + ) + + return { + "config": config, + "input_ids": input_ids, + "attention_mask": input_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + + +@require_tf +@require_retrieval +@require_sentencepiece +@require_tokenizers +class TFRagModelIntegrationTests(unittest.TestCase): + @cached_property + def token_model(self): + return TFRagTokenForGeneration.from_pretrained_question_encoder_generator( + "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn" + ) + + @cached_property + def sequence_model(self): + return TFRagSequenceForGeneration.from_pretrained_question_encoder_generator( + "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn" + ) + + def token_model_nq_checkpoint(self, retriever): + return TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", from_pt=True, retriever=retriever) + + def get_rag_config(self): + question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base") + generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn") + return RagConfig.from_question_encoder_generator_configs( + question_encoder_config, + generator_config, + bos_token_id=0, + decoder_start_token_id=2, + eos_token_id=2, + is_encoder_decoder=True, + pad_token_id=1, + vocab_size=50264, + title_sep=" / ", + doc_sep=" // ", + n_docs=5, + max_combined_length=300, + dataset="wiki_dpr", + dataset_split="train", + index_name="exact", + index_path=None, + use_dummy_dataset=True, + retrieval_vector_size=768, + retrieval_batch_size=8, + ) + + @slow + def test_rag_sequence_inference(self): + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + rag_sequence = self.sequence_model + rag_sequence.set_retriever(rag_retriever) + + input_ids = rag_question_encoder_tokenizer( + "who sings does he love me with reba", return_tensors="tf" + ).input_ids + decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids + + output = rag_sequence( + input_ids, + labels=decoder_input_ids, + ) + + expected_shape = tf.TensorShape([5, 5, 50264]) + self.assertEqual(output.logits.shape, expected_shape) + + expected_doc_scores = tf.convert_to_tensor([[75.0286, 
+        expected_doc_scores = tf.convert_to_tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]])
+        expected_loss = tf.convert_to_tensor([36.7368])
+
+        tf.debugging.assert_near(output.loss, expected_loss, atol=1e-3)
+        tf.debugging.assert_near(output.doc_scores, expected_doc_scores, atol=1e-3)
+
+    @slow
+    def test_rag_token_inference(self):
+        rag_config = self.get_rag_config()
+        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
+            "facebook/dpr-question_encoder-single-nq-base"
+        )
+        rag_retriever = RagRetriever(
+            rag_config,
+            question_encoder_tokenizer=rag_question_encoder_tokenizer,
+            generator_tokenizer=rag_decoder_tokenizer,
+        )
+
+        rag_token = self.token_model
+        rag_token.set_retriever(rag_retriever)
+
+        input_ids = rag_question_encoder_tokenizer(
+            "who sings does he love me with reba", return_tensors="tf"
+        ).input_ids
+        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids
+
+        output = rag_token(
+            input_ids,
+            labels=decoder_input_ids,
+        )
+
+        expected_shape = tf.TensorShape([5, 5, 50264])
+        self.assertEqual(output.logits.shape, expected_shape)
+
+        expected_doc_scores = tf.convert_to_tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]])
+        expected_loss = tf.convert_to_tensor([36.3557])
+
+        tf.debugging.assert_near(output.loss, expected_loss, atol=1e-3)
+        tf.debugging.assert_near(output.doc_scores, expected_doc_scores, atol=1e-3)
+
+    @slow
+    def test_rag_token_inference_nq_checkpoint(self):
+        rag_config = self.get_rag_config()
+        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
+            "facebook/dpr-question_encoder-single-nq-base"
+        )
+        rag_retriever = RagRetriever(
+            rag_config,
+            question_encoder_tokenizer=rag_question_encoder_tokenizer,
+            generator_tokenizer=rag_decoder_tokenizer,
+        )
+
+        rag_token = self.token_model_nq_checkpoint(retriever=rag_retriever)
+
+        # run inference after a save/load roundtrip; outputs are checked against fixed reference values below
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            rag_token.save_pretrained(tmpdirname)
+            rag_token = TFRagTokenForGeneration.from_pretrained(tmpdirname, retriever=rag_retriever)
+
+        input_ids = rag_question_encoder_tokenizer(
+            "who sings does he love me with reba", return_tensors="tf"
+        ).input_ids
+        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids
+
+        output = rag_token(
+            input_ids,
+            labels=decoder_input_ids,
+        )
+
+        expected_shape = tf.TensorShape([5, 5, 50265])
+        self.assertEqual(output.logits.shape, expected_shape)
+
+        expected_doc_scores = tf.convert_to_tensor([[62.9402, 62.7107, 62.2382, 62.1194, 61.8578]])
+        expected_loss = tf.convert_to_tensor([32.521812])
+
+        tf.debugging.assert_near(output.loss, expected_loss, atol=1e-3)
+        tf.debugging.assert_near(output.doc_scores, expected_doc_scores, atol=1e-3)
+
+    @slow
+    def test_rag_token_inference_save_pretrained(self):
+        rag_config = self.get_rag_config()
+        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
+            "facebook/dpr-question_encoder-single-nq-base"
+        )
+        rag_retriever = RagRetriever(
+            rag_config,
+            question_encoder_tokenizer=rag_question_encoder_tokenizer,
+            generator_tokenizer=rag_decoder_tokenizer,
+        )
+
+        rag_token = self.token_model
+        rag_token.set_retriever(rag_retriever)
+
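+        # same question/label pair as test_rag_token_inference above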
return_tensors="tf" + ).input_ids + decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids + + # model must run once to be functional before loading/saving works + rag_token( + input_ids, + labels=decoder_input_ids, + ) + + # check that outputs after saving and loading are equal + with tempfile.TemporaryDirectory() as tmpdirname: + rag_token.save_pretrained(tmpdirname) + rag_token = TFRagTokenForGeneration.from_pretrained(tmpdirname, retriever=rag_retriever) + + output = rag_token( + input_ids, + labels=decoder_input_ids, + ) + + expected_shape = tf.TensorShape([5, 5, 50264]) + self.assertEqual(output.logits.shape, expected_shape) + + expected_doc_scores = tf.convert_to_tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]) + expected_loss = tf.convert_to_tensor([36.3557]) + + tf.debugging.assert_near(output.loss, expected_loss, atol=1e-3) + tf.debugging.assert_near(output.doc_scores, expected_doc_scores, atol=1e-3) + + @slow + def test_init_and_from_pretrained(self): + rag_config = self.get_rag_config() + rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + "facebook/dpr-question_encoder-single-nq-base" + ) + rag_retriever = RagRetriever( + rag_config, + question_encoder_tokenizer=rag_question_encoder_tokenizer, + generator_tokenizer=rag_decoder_tokenizer, + ) + + rag_config = RagConfig.from_pretrained("facebook/rag-sequence-base") + rag = TFRagTokenForGeneration(rag_config, retriever=rag_retriever) + + input_ids = rag_question_encoder_tokenizer( + "who sings does he love me with reba", return_tensors="tf" + ).input_ids + decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids + + rag( + input_ids, + decoder_input_ids=decoder_input_ids, + ) + + # this should not give any warnings + with tempfile.TemporaryDirectory() as tmpdirname: + rag.save_pretrained(tmpdirname) + rag = TFRagTokenForGeneration.from_pretrained(tmpdirname, retriever=rag_retriever) + + @property + def test_data_questions(self): + return [ + "who got the first nobel prize in physics", + "when is the next deadpool movie being released", + "which mode is used for short wave broadcast service", + "who is the owner of reading football club", + "when is the next scandal episode coming out", + "when is the last time the philadelphia won the superbowl", + "what is the most current adobe flash player version", + "how many episodes are there in dragon ball z", + "what is the first step in the evolution of the eye", + "where is gall bladder situated in human body", + "what is the main mineral in lithium batteries", + "who is the president of usa right now", + "where do the greasers live in the outsiders", + "panda is a national animal of which country", + "what is the name of manchester united stadium", + ] + + @slow + def test_rag_token_greedy_search(self): + tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq") + retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True) + rag_token = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True) + + # check first two questions + input_dict = tokenizer( + self.test_data_questions[:2], + return_tensors="tf", + padding=True, + truncation=True, + ) + + input_ids = input_dict.input_ids + attention_mask = input_dict.attention_mask + + # make sure only 1 beam is used + rag_token.config.num_beams = 1 + + output_ids = 
+        output_ids = rag_token.generate(
+            input_ids,
+            attention_mask=attention_mask,
+        )
+
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+        EXPECTED_OUTPUTS = [
+            " albert einstein",
+            " september 22, 2017",
+        ]
+        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
+
+    @slow
+    def test_rag_token_generate_batch(self):
+        # NOTE: gold labels come from num_beams=4, so this is effectively a beam-search test
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
+        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
+        rag_token = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True)
+
+        input_dict = tokenizer(
+            self.test_data_questions,
+            return_tensors="tf",
+            padding=True,
+            truncation=True,
+        )
+
+        input_ids = input_dict.input_ids
+        attention_mask = input_dict.attention_mask
+
+        output_ids = rag_token.generate(
+            input_ids,
+            attention_mask=attention_mask,
+        )
+
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+        EXPECTED_OUTPUTS = [
+            " albert einstein",
+            " september 22, 2017",
+            " amplitude modulation",
+            " stefan persson",
+            " april 20, 2018",
+            " the 1970s",
+            " 7.1. 2",
+            " 13",
+            " evolution",
+            " stomach",
+            " spodumene",
+            " obama",
+            " northern new jersey",
+            " india",
+            " united stadium",
+        ]
+        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
+
+    @slow
+    def test_rag_sequence_generate_batch(self):
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+        retriever = RagRetriever.from_pretrained(
+            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
+        )
+        rag_sequence = TFRagSequenceForGeneration.from_pretrained(
+            "facebook/rag-sequence-nq", retriever=retriever, from_pt=True
+        )
+
+        input_dict = tokenizer(
+            self.test_data_questions,
+            return_tensors="tf",
+            padding=True,
+            truncation=True,
+        )
+
+        input_ids = input_dict.input_ids
+        attention_mask = input_dict.attention_mask
+
+        output_ids = rag_sequence.generate(
+            input_ids,
+            attention_mask=attention_mask,
+        )
+
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+        EXPECTED_OUTPUTS = [
+            " albert einstein",
+            " june 22, 2018",
+            " amplitude modulation",
+            " tim besley ( chairman )",
+            " june 20, 2018",
+            " 1980",
+            " 7.0",
+            " 8",
+            " reticular formation",
+            " walls of the abdomen",
+            " spodumene",
+            " obama",
+            " new orleans",
+            " japan",
+            " old trafford",
+        ]
+        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
+
+    @slow
+    def test_rag_sequence_generate_batch_from_context_input_ids(self):
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+        retriever = RagRetriever.from_pretrained(
+            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
+        )
+        rag_sequence = TFRagSequenceForGeneration.from_pretrained(
+            "facebook/rag-sequence-nq", retriever=retriever, from_pt=True
+        )
+        input_dict = tokenizer(
+            self.test_data_questions,
+            return_tensors="tf",
+            padding=True,
+            truncation=True,
+        )
+
+        input_ids = input_dict.input_ids
+
+        question_hidden_states = rag_sequence.question_encoder(input_ids)[0]
+        docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")
+        doc_scores = tf.squeeze(
+            tf.matmul(
+                tf.expand_dims(question_hidden_states, axis=[1]), docs_dict["retrieved_doc_embeds"], transpose_b=True
+            ),
+            axis=[1],
+        )
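+        # generate directly from the precomputed context inputs and doc scores, bypassing the model's retriever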
+        output_ids = rag_sequence.generate(
+            context_input_ids=docs_dict["context_input_ids"],
+            context_attention_mask=docs_dict["context_attention_mask"],
+            doc_scores=doc_scores,
+            do_deduplication=True,
+        )
+
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+        EXPECTED_OUTPUTS = [
+            " albert einstein",
+            " june 22, 2018",
+            " amplitude modulation",
+            " tim besley ( chairman )",
+            " june 20, 2018",
+            " 1980",
+            " 7.0",
+            " 8",
+            " reticular formation",
+            " walls of the abdomen",
+            " spodumene",
+            " obama",
+            " new orleans",
+            " japan",
+            " old trafford",
+        ]
+        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
+
+
+@require_tf
+@require_retrieval
+class TFRagModelSaveLoadTests(unittest.TestCase):
+    def get_rag_config(self):
+        question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
+        generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
+        return RagConfig.from_question_encoder_generator_configs(
+            question_encoder_config,
+            generator_config,
+            bos_token_id=0,
+            decoder_start_token_id=2,
+            eos_token_id=2,
+            is_encoder_decoder=True,
+            pad_token_id=1,
+            vocab_size=50264,
+            title_sep=" / ",
+            doc_sep=" // ",
+            n_docs=5,
+            max_combined_length=300,
+            dataset="wiki_dpr",
+            dataset_split="train",
+            index_name="exact",
+            index_path=None,
+            use_dummy_dataset=True,
+            retrieval_vector_size=768,
+            retrieval_batch_size=8,
+        )
+
+    @slow
+    def test_rag_sequence_from_pretrained(self):
+        load_weight_prefix = "tf_rag_model_1"
+
+        rag_config = self.get_rag_config()
+        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
+            "facebook/dpr-question_encoder-single-nq-base"
+        )
+        rag_retriever = RagRetriever(
+            rag_config,
+            question_encoder_tokenizer=rag_question_encoder_tokenizer,
+            generator_tokenizer=rag_decoder_tokenizer,
+        )
+
+        input_ids = rag_question_encoder_tokenizer(
+            "who sings does he love me with reba", return_tensors="tf"
+        ).input_ids
+        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            rag_sequence = TFRagSequenceForGeneration.from_pretrained_question_encoder_generator(
+                "facebook/dpr-question_encoder-single-nq-base",
+                "facebook/bart-large-cnn",
+                retriever=rag_retriever,
+                config=rag_config,
+            )
+            # check that the from_pretrained methods work
+            rag_sequence.save_pretrained(tmp_dirname)
+            rag_sequence.from_pretrained(tmp_dirname, retriever=rag_retriever)
+
+            output = rag_sequence(input_ids, labels=decoder_input_ids)
+
+            loss_pretrained = output.loss
+            del rag_sequence
+
+        question_encoder = TFAutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
+        generator = TFAutoModelForSeq2SeqLM.from_pretrained(
+            "facebook/bart-large-cnn", load_weight_prefix=load_weight_prefix, name="generator"
+        )
+
+        rag_sequence = TFRagSequenceForGeneration(
+            config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
+        )
+
+        output = rag_sequence(input_ids, labels=decoder_input_ids)
+
+        loss_init = output.loss
+
+        self.assertAlmostEqual(loss_pretrained, loss_init, places=4)
+
+    @slow
+    def test_rag_token_from_pretrained(self):
+        load_weight_prefix = "tf_rag_model_1"
+
+        rag_config = self.get_rag_config()
+        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
+            "facebook/dpr-question_encoder-single-nq-base"
+        )
+        rag_retriever = RagRetriever(
+            rag_config,
+            question_encoder_tokenizer=rag_question_encoder_tokenizer,
+            generator_tokenizer=rag_decoder_tokenizer,
+        )
+
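+        # same single question/answer pair as in test_rag_sequence_from_pretrained above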
+        input_ids = rag_question_encoder_tokenizer(
+            "who sings does he love me with reba", return_tensors="tf"
+        ).input_ids
+        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            rag_token = TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
+                "facebook/dpr-question_encoder-single-nq-base",
+                "facebook/bart-large-cnn",
+                retriever=rag_retriever,
+                config=rag_config,
+            )
+            # check that the from_pretrained methods work
+            rag_token.save_pretrained(tmp_dirname)
+            rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
+
+            output = rag_token(input_ids, labels=decoder_input_ids)
+
+            loss_pretrained = output.loss
+            del rag_token
+
+        question_encoder = TFAutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
+        generator = TFAutoModelForSeq2SeqLM.from_pretrained(
+            "facebook/bart-large-cnn", load_weight_prefix=load_weight_prefix, name="generator"
+        )
+        rag_token = TFRagTokenForGeneration(
+            config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
+        )
+
+        output = rag_token(input_ids, labels=decoder_input_ids)
+
+        loss_init = output.loss
+
+        self.assertAlmostEqual(loss_pretrained, loss_init, places=4)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 0db06c9e7..afcc4cbd7 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -121,6 +121,9 @@ IGNORE_NON_AUTO_CONFIGURED = [
     "TFGPT2DoubleHeadsModel",
     "TFMT5EncoderModel",
     "TFOpenAIGPTDoubleHeadsModel",
+    "TFRagModel",
+    "TFRagSequenceForGeneration",
+    "TFRagTokenForGeneration",
     "TFT5EncoderModel",
     "Wav2Vec2ForCTC",
     "XLMForQuestionAnswering",
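
For reviewers who want to exercise the new classes by hand, the integration tests above reduce to the following minimal usage sketch. It reuses the same public checkpoint and dummy wiki_dpr index as the tests; from_pt=True is passed exactly as in the tests, since the checkpoint currently ships PyTorch weights.

    from transformers import RagRetriever, RagTokenizer, TFRagTokenForGeneration

    # load the tokenizer, the retriever (the dummy index keeps the download small), and the TF model
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
    model = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True)

    input_dict = tokenizer("who got the first nobel prize in physics", return_tensors="tf")
    output_ids = model.generate(input_dict.input_ids, attention_mask=input_dict.attention_mask)
    # with the dummy index, the tests above expect this question to decode to " albert einstein"
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))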