diff --git a/src/transformers/modeling_rag.py b/src/transformers/modeling_rag.py
index 227cf13c0..d09df68ee 100644
--- a/src/transformers/modeling_rag.py
+++ b/src/transformers/modeling_rag.py
@@ -458,6 +458,8 @@ RAG_FORWARD_INPUTS_DOCSTRING = r"""
         output_retrieved(:obj:`bool`, `optional`):
             Whether or not to return the :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`,
             :obj:`context_input_ids` and :obj:`context_attention_mask`. See returned tensors for more detail.
+        n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`):
+            Number of documents to retrieve and/or number of documents for which to generate an answer.
 """
 
 
@@ -521,6 +523,7 @@ class RagModel(RagPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         output_retrieved=None,
+        n_docs=None,
     ):
         r"""
         Returns:
@@ -540,6 +543,7 @@ class RagModel(RagPreTrainedModel):
             >>> outputs = model(input_ids=input_ids)
 
         """
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -566,7 +570,7 @@ class RagModel(RagPreTrainedModel):
                 input_ids,
                 question_encoder_last_hidden_state.cpu().detach().to(torch.float32).numpy(),
                 prefix=self.generator.config.prefix,
-                n_docs=self.config.n_docs,
+                n_docs=n_docs,
                 return_tensors="pt",
             )
             context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
@@ -600,12 +604,16 @@ class RagModel(RagPreTrainedModel):
                 doc_scores is not None
             ), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
 
+        assert (
+            doc_scores.shape[1] % n_docs
+        ) == 0, f"The second dimension of `doc_scores` should be a multiple of `n_docs`={n_docs}, but is {doc_scores.shape[1]}."
+
         # Decoder input without context documents
         if decoder_input_ids is not None:
-            decoder_input_ids = decoder_input_ids.repeat_interleave(self.config.n_docs, dim=0)
+            decoder_input_ids = decoder_input_ids.repeat_interleave(n_docs, dim=0)
 
         if decoder_attention_mask is not None:
-            decoder_attention_mask = decoder_attention_mask.repeat_interleave(self.config.n_docs, dim=0)
+            decoder_attention_mask = decoder_attention_mask.repeat_interleave(n_docs, dim=0)
 
         gen_outputs = self.generator(
             input_ids=context_input_ids,
@@ -702,6 +710,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         exclude_bos_score=None,
         reduce_loss=None,
         labels=None,
+        n_docs=None,
         **kwargs  # needs kwargs for generation
     ):
         r"""
@@ -741,6 +750,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             >>> # 3. Forward to generator
             >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
 
         """
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
         reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss
@@ -763,6 +773,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             output_retrieved=output_retrieved,
+            n_docs=n_docs,
         )
 
         loss = None
@@ -774,6 +785,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
                 reduce_loss=reduce_loss,
                 epsilon=self.config.label_smoothing,
                 exclude_bos_score=exclude_bos_score,
+                n_docs=n_docs,
             )
 
         return RetrievAugLMMarginOutput(
@@ -816,6 +828,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         do_deduplication=None,  # defaults to True
         num_return_sequences=None,  # defaults to 1
         num_beams=None,  # defaults to 1
+        n_docs=None,
         **kwargs
     ):
         """
@@ -847,6 +860,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
                 function, where we set ``num_return_sequences`` to :obj:`num_beams`.
             num_beams (:obj:`int`, `optional`, defaults to 1):
                 Number of beams for beam search. 1 means no beam search.
+            n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`):
+                Number of documents to retrieve and/or number of documents for which to generate an answer.
             kwargs:
                 Additional kwargs will be passed to :meth:`~transformers.PreTrainedModel.generate`.
 
@@ -856,6 +871,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             shorter if all batches finished early due to the :obj:`eos_token_id`.
         """
 
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication
         num_doc_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
@@ -869,7 +885,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
                 input_ids,
                 question_hidden_states.cpu().detach().to(torch.float32).numpy(),
                 prefix=self.generator.config.prefix,
-                n_docs=self.config.n_docs,
+                n_docs=n_docs,
                 return_tensors="pt",
             )["context_input_ids"]
 
@@ -880,12 +896,11 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         kwargs["num_beams"] = num_beams
         kwargs["num_return_sequences"] = num_beams
         kwargs["attention_mask"] = None
+        kwargs["n_docs"] = n_docs
 
         for index in range(len(input_ids)):
             # first, generate beams from documents:
-            generator_input_ids = context_input_ids[
-                index * self.config.n_docs : (index + 1) * self.config.n_docs
-            ]  # (n_docs, max_len)
+            generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs]  # (n_docs, max_len)
 
             output_sequences = self.generator.generate(
                 generator_input_ids,
@@ -905,12 +920,16 @@ class RagSequenceForGeneration(RagPreTrainedModel):
 
         return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id)
 
-    def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False):
+    def get_nll(
+        self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
+    ):
         # shift tokens left
         target = torch.cat(
             [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
         )
 
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
+
         # bos_token_id is None for T5
         bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
         use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()
@@ -923,7 +942,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             return ll.squeeze(-1), smooth_obj.squeeze(-1)
 
         seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view(
-            seq_logits.shape[0] // self.config.n_docs, self.config.n_docs, -1, seq_logits.size(-1)
+            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
         )  # batch_size x n_docs x tgt_len x dim
         doc_logprobs = torch.nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)
 
@@ -934,7 +953,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         rag_logprobs = torch.cat([first_token_scores, second_token_scores + doc_logprobs, remainder], dim=2)
 
         # calculate loss
-        target = target.unsqueeze(1).unsqueeze(-1).repeat(1, self.config.n_docs, 1, 1)
+        target = target.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
         assert target.dim() == rag_logprobs.dim()
 
         ll = rag_logprobs.gather(dim=-1, index=target)
@@ -1004,7 +1023,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)
 
     def prepare_inputs_for_generation(
-        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, **kwargs
+        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, n_docs=None, **kwargs
     ):
         return {
             "input_ids": None,
@@ -1015,6 +1034,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             "past_key_values": past,
             "use_cache": use_cache,
             "do_marginalize": True,
+            "n_docs": n_docs,
         }
 
     @property
@@ -1053,10 +1073,13 @@ class RagTokenForGeneration(RagPreTrainedModel):
 
         return reordered_past
 
-    def marginalize(self, seq_logits, doc_scores):
+    def marginalize(self, seq_logits, doc_scores, n_docs=None):
+
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
+
         # RAG-token marginalization
         seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view(
-            seq_logits.shape[0] // self.config.n_docs, self.config.n_docs, -1, seq_logits.size(-1)
+            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
         )
         doc_logprobs = torch.log_softmax(doc_scores, dim=1)
         log_prob_sum = seq_logprobs + doc_logprobs.unsqueeze(-1).unsqueeze(-1)
@@ -1082,6 +1105,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         do_marginalize=None,
         reduce_loss=None,
         labels=None,
+        n_docs=None,
         **kwargs  # needs kwargs for generation
     ):
         r"""
@@ -1124,6 +1148,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             >>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
             >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
         """
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         do_marginalize = do_marginalize if do_marginalize is not None else self.config.do_marginalize
         reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss
 
@@ -1146,6 +1171,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             output_retrieved=output_retrieved,
+            n_docs=n_docs,
         )
 
         loss = None
@@ -1158,10 +1184,11 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 labels,
                 reduce_loss=reduce_loss,
                 epsilon=self.config.label_smoothing,
+                n_docs=n_docs,
             )
 
         if do_marginalize:
-            logits = self.marginalize(logits, outputs.doc_scores)
+            logits = self.marginalize(logits, outputs.doc_scores, n_docs)
 
         return RetrievAugLMMarginOutput(
             loss=loss,
@@ -1203,6 +1230,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         bad_words_ids=None,
         num_return_sequences=None,
         decoder_start_token_id=None,
+        n_docs=None,
         **kwargs
     ):
         """
@@ -1274,6 +1302,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 function, where we set ``num_return_sequences`` to :obj:`num_beams`.
             decoder_start_token_id (:obj:`int`, `optional`):
                 If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
+            n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`):
+                Number of documents to retrieve and/or number of documents for which to generate an answer.
 
         Return:
             :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
@@ -1281,6 +1311,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             shorter if all batches finished early due to the :obj:`eos_token_id`.
         """
         # set default parameters
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         max_length = max_length if max_length is not None else self.config.max_length
         min_length = min_length if min_length is not None else self.config.min_length
         early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
@@ -1310,7 +1341,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 input_ids,
                 question_hidden_states.cpu().detach().to(torch.float32).numpy(),
                 prefix=self.generator.config.prefix,
-                n_docs=self.config.n_docs,
+                n_docs=n_docs,
                 return_tensors="pt",
            )
             context_input_ids, context_attention_mask, retrieved_doc_embeds = (
@@ -1329,8 +1360,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 1
             )
 
+        assert (
+            context_input_ids.shape[0] % n_docs
+        ) == 0, f"The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+
         # batch_size
-        batch_size = context_input_ids.shape[0] // self.config.n_docs
+        batch_size = context_input_ids.shape[0] // n_docs
 
         encoder = self.rag.generator.get_encoder()
         encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
@@ -1345,11 +1380,11 @@ class RagTokenForGeneration(RagPreTrainedModel):
 
         def extend_enc_output(tensor, num_beams=None):
             # split into `batch_size`, `num_beams`, `num_docs`
-            tensor = tensor[None, None, :].reshape((batch_size, 1, self.config.n_docs) + tensor.shape[1:])
+            tensor = tensor[None, None, :].reshape((batch_size, 1, n_docs) + tensor.shape[1:])
             # repeat same last hidden states over `num_beams` dimension
-            tensor = tensor.expand((batch_size, num_beams, self.config.n_docs) + tensor.shape[3:])
+            tensor = tensor.expand((batch_size, num_beams, n_docs) + tensor.shape[3:])
             # merge `batch_size`, `num_beams`, `num_docs` dims again
-            return tensor.reshape((batch_size * num_beams * self.config.n_docs,) + tensor.shape[3:])
+            return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])
 
         # correctly extend last_hidden_state and attention mask
         context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
@@ -1362,6 +1397,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         vocab_size = self.config.generator.vocab_size
         kwargs["doc_scores"] = doc_scores
         kwargs["encoder_outputs"] = encoder_outputs
+        kwargs["n_docs"] = n_docs
 
         # not needed. TODO(PVP): change after generate refactor
         do_sample = False
@@ -1431,7 +1467,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
         shifted_input_ids[:, 0] = start_token_id
         return shifted_input_ids
 
-    def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0):
+    def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None):
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         # shift tokens left
         target = torch.cat(
             [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
@@ -1444,7 +1481,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             smooth_obj.masked_fill_(pad_mask, 0.0)
             return ll.squeeze(-1), smooth_obj.squeeze(-1)
 
-        rag_logprobs = self.marginalize(seq_logits, doc_scores)
+        rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
 
         target = target.unsqueeze(-1)
         assert target.dim() == rag_logprobs.dim()
diff --git a/tests/test_modeling_rag.py b/tests/test_modeling_rag.py
index b4dfea9b8..671766173 100644
--- a/tests/test_modeling_rag.py
+++ b/tests/test_modeling_rag.py
@@ -82,7 +82,7 @@ def require_retrieval(test_case):
 
     """
     if not (is_torch_available() and is_datasets_available() and is_faiss_available()):
-        test_case = unittest.skip("test requires PyTorch")(test_case)
+        test_case = unittest.skip("test requires PyTorch, datasets and faiss")(test_case)
     return test_case
 
 
@@ -98,7 +98,7 @@ class RagTestMixin:
     )
 
     retrieval_vector_size = 32
-    n_docs = 2
+    n_docs = 3
     max_combined_length = 16
 
     def setUp(self):
@@ -186,10 +186,14 @@ class RagTestMixin:
     def get_retriever(self, config):
         dataset = Dataset.from_dict(
             {
-                "id": ["0", "1"],
-                "text": ["foo", "bar"],
-                "title": ["Foo", "Bar"],
-                "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
+                "id": ["0", "1", "3"],
+                "text": ["foo", "bar", "qux"],
+                "title": ["Foo", "Bar", "Qux"],
+                "embeddings": [
+                    np.ones(self.retrieval_vector_size),
+                    2 * np.ones(self.retrieval_vector_size),
+                    3 * np.ones(self.retrieval_vector_size),
+                ],
             }
         )
         dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
@@ -315,6 +319,125 @@ class RagTestMixin:
         # doc scores
         self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
 
+    def check_model_custom_n_docs(
+        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, n_docs, **kwargs
+    ):
+        self.assertIsNotNone(config.question_encoder)
+        self.assertIsNotNone(config.generator)
+
+        retriever = self.get_retriever(config)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+            self.assertTrue(model.config.is_encoder_decoder)
+
+            question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
+
+            out = retriever(
+                input_ids,
+                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+                prefix=config.generator.prefix,
+                return_tensors="pt",
+                n_docs=n_docs,
+            )
+
+            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
+                out["context_input_ids"],
+                out["context_attention_mask"],
+                out["retrieved_doc_embeds"],
+            )
+
+            # cast
+            retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
+            context_input_ids = context_input_ids.to(input_ids)
+            context_attention_mask = context_attention_mask.to(input_ids)
+
+            # compute doc_scores
+            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
+                1
+            )
+
+            outputs = model(
+                context_input_ids=context_input_ids,
+                context_attention_mask=context_attention_mask,
+                doc_scores=doc_scores,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                n_docs=n_docs,
+            )
+
+            # logits
+            self.assertEqual(
+                outputs.logits.shape,
+                (n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
+            )
+            # generator encoder last hidden states
+            self.assertEqual(
+                outputs.generator_enc_last_hidden_state.shape,
+                (n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
+            )
+            # doc scores
+            self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], n_docs))
+
+    def check_model_with_mismatch_n_docs_value(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        retriever_n_docs,
+        generator_n_docs,
+        **kwargs
+    ):
+        self.assertIsNotNone(config.question_encoder)
+        self.assertIsNotNone(config.generator)
+
+        retriever = self.get_retriever(config)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+            self.assertTrue(model.config.is_encoder_decoder)
+
+            question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
+
+            out = retriever(
+                input_ids,
+                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+                prefix=config.generator.prefix,
+                return_tensors="pt",
+                n_docs=retriever_n_docs,
+            )
+
+            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
+                out["context_input_ids"],
+                out["context_attention_mask"],
+                out["retrieved_doc_embeds"],
+            )
+
+            # cast
+            retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
+            context_input_ids = context_input_ids.to(input_ids)
+            context_attention_mask = context_attention_mask.to(input_ids)
+
+            # compute doc_scores
+            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
+                1
+            )
+
+            self.assertRaises(
+                AssertionError,
+                model.__call__,
+                context_input_ids=context_input_ids,
+                context_attention_mask=context_attention_mask,
+                doc_scores=doc_scores,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                n_docs=generator_n_docs,
+            )
+
     def check_model_with_encoder_outputs(
         self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
     ):
@@ -373,6 +496,17 @@ class RagTestMixin:
         inputs_dict = self.config_and_inputs
         self.check_model_generate(**inputs_dict)
 
+    def test_model_with_custom_n_docs(self):
+        inputs_dict = self.config_and_inputs
+        inputs_dict["n_docs"] = 1
+        self.check_model_custom_n_docs(**inputs_dict)
+
+    def test_model_with_mismatch_n_docs_value(self):
+        inputs_dict = self.config_and_inputs
+        inputs_dict["retriever_n_docs"] = 3
+        inputs_dict["generator_n_docs"] = 2
+        self.check_model_with_mismatch_n_docs_value(**inputs_dict)
+
 
 @require_torch
 @require_retrieval
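
Note on usage: with this change, `n_docs` can be overridden per call instead of always being read from `config.n_docs`, for `forward`, `generate`, and the retriever call alike. A minimal sketch of the per-call override (checkpoint name and retriever settings are taken from the examples already used in this file's docstrings; the dummy index is for illustration only):

    from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
    model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

    input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt")

    # retrieve and marginalize over 2 documents for this call only,
    # overriding the model-wide config.n_docs default
    generated = model.generate(input_ids=input_dict["input_ids"], n_docs=2)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))

Passing an `n_docs` that does not match the number of documents actually retrieved trips the new assertions above, which is exactly what `test_model_with_mismatch_n_docs_value` exercises.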