Allow Custom Dataset in RAG Retriever (#7763)
* add CustomHFIndex
* typo in config
* update tests
* add custom dataset example
* clean script
* update test data
* minor in test
* docs
* docs
* style
* fix imports
* allow to pass the indexed dataset directly
* update tests
* use multiset DPR
* address thom and patrick's comments
* style
* update dpr tokenizer
* add output_dir flag in use_own_knowledge_dataset.py
* allow custom datasets in examples/rag/finetune.py
* add test for custom dataset in distributed rag retriever
This commit is contained in:
parent a09fe140c1
commit 033f29c625
13 changed files with 663 additions and 98 deletions
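In short, this change adds three ways to point a RagRetriever at an index. A minimal sketch of the resulting API (mirroring the docstring examples added to retrieval_rag.py below; the model name, paths, and `dataset` are placeholders):

    from transformers import RagRetriever

    # dataset: an indexed datasets.Dataset with "title", "text" and "embeddings" columns
    # 1) Pass the indexed dataset directly; from_pretrained switches the config to the new "custom" index.
    retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", indexed_dataset=dataset)

    # 2) Load a custom dataset and its faiss index from disk, as produced by
    #    examples/rag/use_own_knowledge_dataset.py.
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        passages_path="path/to/my/dataset",    # written by dataset.save_to_disk(...)
        index_path="path/to/my/index.faiss",   # written by dataset.get_index("embeddings").save(...)
    )

    # 3) Keep a canonical indexed dataset from the datasets library (the default), or the legacy index.
    retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", dataset="wiki_dpr", index_name="compressed")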
@@ -62,8 +62,7 @@ Rag specific outputs

.. autoclass:: transformers.modeling_rag.RetrievAugLMOutput
    :members:


-RAGRetriever
+RagRetriever
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RagRetriever
@@ -27,13 +27,18 @@ class RagPyTorchDistributedRetriever(RagRetriever):
            It is used to decode the question and then use the generator_tokenizer.
        generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
+       index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
+           If specified, use this index instead of the one built using the configuration
    """

    _init_retrieval = False

-   def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
+   def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
        super().__init__(
-           config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
+           config,
+           question_encoder_tokenizer=question_encoder_tokenizer,
+           generator_tokenizer=generator_tokenizer,
+           index=index,
        )

        self.process_group = None
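Construction with a prebuilt index then mirrors the distributed test added later in this diff; a sketch, assuming `config` is a RagConfig with `index_name="custom"`, `dataset` is an indexed `datasets.Dataset`, and the two tokenizers are whatever DPR/BART tokenizers you use:

    from transformers.retrieval_rag import CustomHFIndex

    retriever = RagPyTorchDistributedRetriever(
        config,
        question_encoder_tokenizer=question_encoder_tokenizer,  # DPR question encoder tokenizer
        generator_tokenizer=generator_tokenizer,                # BART tokenizer
        index=CustomHFIndex(config.retrieval_vector_size, dataset),
    )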
@@ -90,6 +90,11 @@ class GenerativeQAModule(BaseTransformer):
        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

+       # set retriever parameters
+       config.index_name = args.index_name or config.index_name
+       config.passages_path = args.passages_path or config.passages_path
+       config.index_path = args.index_path or config.index_path
+
        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
        if self.is_rag_model:
@@ -97,7 +102,7 @@ class GenerativeQAModule(BaseTransformer):
            config.generator.prefix = args.prefix
            config.label_smoothing = hparams.label_smoothing
            hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
-           retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path)
+           retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
            model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
            prefix = config.question_encoder.prefix
        else:
@@ -405,6 +410,28 @@ class GenerativeQAModule(BaseTransformer):
        )
        return parser

+   @staticmethod
+   def add_retriever_specific_args(parser):
+       parser.add_argument(
+           "--index_name",
+           type=str,
+           default=None,
+           help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the original one",
+       )
+       parser.add_argument(
+           "--passages_path",
+           type=str,
+           default=None,
+           help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+       )
+       parser.add_argument(
+           "--index_path",
+           type=str,
+           default=None,
+           help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+       )
+       return parser
+
+
def main(args, model=None) -> GenerativeQAModule:
    Path(args.output_dir).mkdir(exist_ok=True)
@@ -465,6 +492,7 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
+   parser = GenerativeQAModule.add_retriever_specific_args(parser)

    args = parser.parse_args()
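The three new flags flow into the retriever config with an `or` fallback, so anything left unset keeps the value stored in the pretrained config (see the `# set retriever parameters` hunk above). A standalone sketch of that wiring, assuming `config` is the RagConfig loaded in GenerativeQAModule and the paths are placeholders:

    import argparse

    parser = argparse.ArgumentParser()
    parser = GenerativeQAModule.add_retriever_specific_args(parser)
    args = parser.parse_args(
        ["--index_name", "custom",
         "--passages_path", "/data/my_knowledge_dataset",
         "--index_path", "/data/my_knowledge_dataset_hnsw_index.faiss"]
    )
    # Unset flags (None) fall back to the config defaults:
    config.index_name = args.index_name or config.index_name
    config.passages_path = args.passages_path or config.passages_path
    config.index_path = args.index_path or config.index_path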
2 examples/rag/test_data/my_knowledge_dataset.csv Normal file

@@ -0,0 +1,2 @@
Aaron Aaron Aaron ( or ; "Ahärôn") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman ("prophet") to the Pharaoh. Part of the Law (Torah) that Moses received from God at Sinai granted Aaron the priesthood for himself and his male descendants, and he became the first High Priest of the Israelites. Aaron died before the Israelites crossed the North Jordan river and he was buried on Mount Hor (Numbers 33:39; Deuteronomy 10:6 says he died and was buried at Moserah). Aaron is also mentioned in the New Testament of the Bible. According to the Book of Exodus, Aaron first functioned as Moses' assistant. Because Moses complained that he could not speak well, God appointed Aaron as Moses' "prophet" (Exodus 4:10-17; 7:1). At the command of Moses, he let his rod turn into a snake. Then he stretched out his rod in order to bring on the first three plagues. After that, Moses tended to act and speak for himself. During the journey in the wilderness, Aaron was not always prominent or active. At the battle with Amalek, he was chosen with Hur to support the hand of Moses that held the "rod of God". When the revelation was given to Moses at biblical Mount Sinai, he headed the elders of Israel who accompanied Moses on the way to the summit.
"Pokémon" Pokémon , also known as in Japan, is a media franchise managed by The Pokémon Company, a Japanese consortium between Nintendo, Game Freak, and Creatures. The franchise copyright is shared by all three companies, but Nintendo is the sole owner of the trademark. The franchise was created by Satoshi Tajiri in 1995, and is centered on fictional creatures called "Pokémon", which humans, known as Pokémon Trainers, catch and train to battle each other for sport. The English slogan for the franchise is "Gotta Catch 'Em All". Works within the franchise are set in the Pokémon universe. The franchise began as "Pokémon Red" and "Green" (released outside of Japan as "Pokémon Red" and "Blue"), a pair of video games for the original Game Boy that were developed by Game Freak and published by Nintendo in February 1996. "Pokémon" has since gone on to become the highest-grossing media franchise of all time, with over in revenue up until March 2017. The original video game series is the second best-selling video game franchise (behind Nintendo's "Mario" franchise) with more than 300million copies sold and over 800million mobile downloads. In addition, the "Pokémon" franchise includes the world's top-selling toy brand, the top-selling trading card game with over 25.7billion cards sold, an anime television series that has become the most successful video game adaptation with over 20 seasons and 1,000 episodes in 124 countries, as well as an anime film series, a , books, manga comics, music, and merchandise. The franchise is also represented in other Nintendo media, such as the "Super Smash Bros." series. In November 2005, 4Kids Entertainment, which had managed the non-game related licensing of "Pokémon", announced that it had agreed not to renew the "Pokémon" representation agreement. The Pokémon Company International oversees all "Pokémon" licensing outside Asia.
@@ -15,6 +15,7 @@ from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
+from transformers.retrieval_rag import CustomHFIndex
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer

@@ -114,7 +115,7 @@ class RagRetrieverTest(TestCase):
    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

-   def get_dummy_pytorch_distributed_retriever(self, init_retrieval, port=12345) -> RagPyTorchDistributedRetriever:
+   def get_dummy_dataset(self):
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],

@@ -124,6 +125,12 @@ class RagRetrieverTest(TestCase):
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
+       return dataset
+
+   def get_dummy_pytorch_distributed_retriever(
+       self, init_retrieval: bool, port=12345
+   ) -> RagPyTorchDistributedRetriever:
+       dataset = self.get_dummy_dataset()
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),

@@ -140,6 +147,37 @@ class RagRetrieverTest(TestCase):
            retriever.init_retrieval(port)
        return retriever

+   def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
+       dataset = self.get_dummy_dataset()
+       config = RagConfig(
+           retrieval_vector_size=self.retrieval_vector_size,
+           question_encoder=DPRConfig().to_dict(),
+           generator=BartConfig().to_dict(),
+           index_name="custom",
+       )
+       if from_disk:
+           config.passages_path = os.path.join(self.tmpdirname, "dataset")
+           config.index_path = os.path.join(self.tmpdirname, "index.faiss")
+           dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
+           dataset.drop_index("embeddings")
+           dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
+           del dataset
+           retriever = RagPyTorchDistributedRetriever(
+               config,
+               question_encoder_tokenizer=self.get_dpr_tokenizer(),
+               generator_tokenizer=self.get_bart_tokenizer(),
+           )
+       else:
+           retriever = RagPyTorchDistributedRetriever(
+               config,
+               question_encoder_tokenizer=self.get_dpr_tokenizer(),
+               generator_tokenizer=self.get_bart_tokenizer(),
+               index=CustomHFIndex(config.retrieval_vector_size, dataset),
+           )
+       if init_retrieval:
+           retriever.init_retrieval(port)
+       return retriever
+
    def test_pytorch_distributed_retriever_retrieve(self):
        n_docs = 1
        retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)

@@ -154,3 +192,33 @@ class RagRetrieverTest(TestCase):
        self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
        self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
        self.assertListEqual(doc_ids.tolist(), [[1], [0]])
+
+   def test_custom_hf_index_retriever_retrieve(self):
+       n_docs = 1
+       retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
+       hidden_states = np.array(
+           [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
+       )
+       retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
+       self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
+       self.assertEqual(len(doc_dicts), 2)
+       self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
+       self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
+       self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
+       self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
+       self.assertListEqual(doc_ids.tolist(), [[1], [0]])
+
+   def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
+       n_docs = 1
+       retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
+       hidden_states = np.array(
+           [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
+       )
+       retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
+       self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
+       self.assertEqual(len(doc_dicts), 2)
+       self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
+       self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
+       self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
+       self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
+       self.assertListEqual(doc_ids.tolist(), [[1], [0]])
199 examples/rag/use_own_knowledge_dataset.py Normal file

@@ -0,0 +1,199 @@
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import torch
from datasets import load_dataset

import faiss
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    HfArgumentParser,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)


logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"


def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]


def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        for passage in split_text(text):
            titles.append(title)
            texts.append(passage)
    return {"title": titles, "text": texts}


def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages"""
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.detach().cpu().numpy()}


def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):

    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # The dataset needed for RAG must have three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage

    # Let's say you have documents in tab-separated csv files with columns "title" and "text"
    assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # You can load a Dataset object this way
    dataset = load_dataset(
        "csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
    )

    # More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

    # Then split the documents into passages of 100 words
    dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # And compute the embeddings
    ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
    )

    # And finally save your dataset
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)
    # from datasets import load_from_disk
    # dataset = load_from_disk(passages_path)  # to reload the dataset

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)

    # And save the index
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)
    # dataset.load_faiss_index("embeddings", index_path)  # to reload the index

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # Easy way to load the model
    retriever = RagRetriever.from_pretrained(
        rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset
    )
    model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    # For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
    # retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    question = rag_example_args.question or "What does Moses' rod turn into ?"
    input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids)
    generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    logger.info("Q: " + question)
    logger.info("A: " + generated_string)


@dataclass
class RagExampleArguments:
    csv_path: str = field(
        default=str(Path(__file__).parent / "test_data" / "my_knowledge_dataset.csv"),
        metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
    )
    question: Optional[str] = field(
        default=None,
        metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
    )
    rag_model_name: str = field(
        default="facebook/rag-sequence-nq",
        metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
    )
    dpr_ctx_encoder_model_name: str = field(
        default="facebook/dpr-ctx_encoder-multiset-base",
        metadata={
            "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
        },
    )
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
    )


@dataclass
class ProcessingArguments:
    num_proc: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use to split the documents into passages. Default is single process."
        },
    )
    batch_size: int = field(
        default=16,
        metadata={
            "help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
        },
    )


@dataclass
class IndexHnswArguments:
    d: int = field(
        default=768,
        metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
    )
    m: int = field(
        default=128,
        metadata={
            "help": "The number of bi-directional links created for every new element during the HNSW index construction."
        },
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    logger.setLevel(logging.INFO)

    parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
    rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
    with TemporaryDirectory() as tmp_dir:
        rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
        main(rag_example_args, processing_args, index_hnsw_args)
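For intuition on the chunking in this script: `split_text` joins every run of `n` space-separated tokens back into one passage, so short inputs behave like this (a doctest-style sketch):

    >>> split_text("one two three four five", n=2)
    ['one two', 'three four', 'five']

Each passage inherits its document's title in `split_documents`, which is what lets `embed` feed "title, text" pairs to the DPR context encoder.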
@@ -24,6 +24,9 @@ DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json",
    "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-question_encoder-single-nq-base/config.json",
    "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-reader-single-nq-base/config.json",
+   "facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-multiset-base/config.json",
+   "facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-question_encoder-multiset-base/config.json",
+   "facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-reader-multiset-base/config.json",
}
@@ -41,7 +41,7 @@ RAG_CONFIG_DOC = r"""
            Retrieval batch size, defined as the number of queries issued concurrently to the faiss index encapsulated by
            :class:`~transformers.RagRetriever`.
        dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`):
-           A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and
+           A dataset identifier of the indexed dataset in HuggingFace Datasets (list all available datasets and
            ids using :obj:`datasets.list_datasets()`).
        dataset_split (:obj:`str`, `optional`, defaults to :obj:`"train"`)
            Which split of the :obj:`dataset` to load.
@@ -35,12 +35,15 @@ _CONFIG_FOR_DOC = "DPRConfig"

DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/dpr-ctx_encoder-single-nq-base",
+   "facebook/dpr-ctx_encoder-multiset-base",
]
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/dpr-question_encoder-single-nq-base",
+   "facebook/dpr-question_encoder-multiset-base",
]
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/dpr-reader-single-nq-base",
+   "facebook/dpr-reader-multiset-base",
]
@@ -36,7 +36,7 @@ from .utils import logging


if is_datasets_available():
-   from datasets import load_dataset
+   from datasets import Dataset, load_dataset, load_from_disk

if is_faiss_available():
    import faiss
@@ -93,7 +93,7 @@ class Index:
        raise NotImplementedError


-class LegacyIndex:
+class LegacyIndex(Index):
    """
    An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR.
    We use default faiss index parameters as specified in that repository.
@@ -115,7 +115,7 @@ class LegacyIndex:
        self.passages = self._load_passages()
        self.vector_size = vector_size
        self.index = None
-       self._index_initialize = False
+       self._index_initialized = False

    def _resolve_path(self, index_path, filename):
        assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid ``index_path``."
@@ -157,7 +157,7 @@ class LegacyIndex:
        ), "Deserialized index_id_to_db_id should match faiss index size"

    def is_initialized(self):
-       return self._index_initialize
+       return self._index_initialized

    def init_index(self):
        index = faiss.IndexHNSWFlat(self.vector_size + 1, 512)
@@ -165,7 +165,7 @@ class LegacyIndex:
        index.hnsw.efConstruction = 200
        self.index = index
        self._deserialize_index()
-       self._index_initialize = True
+       self._index_initialized = True

    def get_doc_dicts(self, doc_ids: np.array):
        doc_list = []
@@ -190,65 +190,34 @@ class LegacyIndex:
        return np.array(ids), np.array(vectors)


-class HFIndex:
-   """
-   A wrapper around an instance of :class:`~datasets.Datasets`. If ``index_path`` is set to ``None``,
-   we load the pre-computed index available with the :class:`~datasets.arrow_dataset.Dataset`, otherwise, we load the index from the indicated path on disk.
-
-   Args:
-       dataset (:obj:`str`, optional, defaults to ``wiki_dpr``):
-           A datatset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids with ``datasets.list_datasets()``).
-       dataset_split (:obj:`str`, optional, defaults to ``train``)
-           Which split of the ``dataset`` to load.
-       index_name (:obj:`str`, optional, defaults to ``train``)
-           The index_name of the index associated with the ``dataset``. The index loaded from ``index_path`` will be saved under this name.
-       index_path (:obj:`str`, optional, defaults to ``None``)
-           The path to the serialized faiss index on disk.
-   """
-
-   def __init__(
-       self,
-       dataset_name: str,
-       dataset_split: str,
-       index_name: str,
-       vector_size: int,
-       index_path: Optional[str] = None,
-       use_dummy_dataset=False,
-   ):
-       super().__init__()
-       self.dataset_name = dataset_name
-       self.dataset_split = dataset_split
-       self.index_name = index_name
+class HFIndexBase(Index):
+   def __init__(self, vector_size, dataset, index_initialized=False):
        self.vector_size = vector_size
-       self.index_path = index_path
-       self.use_dummy_dataset = use_dummy_dataset
-       self._index_initialize = False
+       self.dataset = dataset
+       self._index_initialized = index_initialized
+       self._check_dataset_format(with_index=index_initialized)
+       dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)

-       logger.info("Loading passages from {}".format(self.dataset_name))
-       self.dataset = load_dataset(
-           self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset
-       )
-       self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
-
-   def is_initialized(self):
-       return self._index_initialize
+   def _check_dataset_format(self, with_index: bool):
+       if not isinstance(self.dataset, Dataset):
+           raise ValueError("Dataset should be a datasets.Dataset object, but got {}".format(type(self.dataset)))
+       if len({"title", "text", "embeddings"} - set(self.dataset.column_names)) > 0:
+           raise ValueError(
+               "Dataset should be a dataset with the following columns: "
+               "title (str), text (str) and embeddings (arrays of dimension vector_size), "
+               "but got columns {}".format(self.dataset.column_names)
+           )
+       if with_index and "embeddings" not in self.dataset.list_indexes():
+           raise ValueError(
+               "Missing faiss index in the dataset. Make sure you called `dataset.add_faiss_index` to compute it "
+               "or `dataset.load_faiss_index` to load one from the disk."
+           )

    def init_index(self):
-       if self.index_path is not None:
-           logger.info("Loading index from {}".format(self.index_path))
-           self.index.load_faiss_index(index_name=self.index_name, file=self.index_path)
-       else:
-           logger.info("Loading index from {}".format(self.dataset_name + " with index name " + self.index_name))
-           self.dataset = load_dataset(
-               self.dataset_name,
-               with_embeddings=True,
-               with_index=True,
-               split=self.dataset_split,
-               index_name=self.index_name,
-               dummy=self.use_dummy_dataset,
-           )
-           self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
-       self._index_initialize = True
+       raise NotImplementedError()
+
+   def is_initialized(self):
+       return self._index_initialized

    def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
        return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]
@@ -263,6 +232,100 @@ class HFIndex:
        return np.array(ids), np.array(vectors)  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)


+class CanonicalHFIndex(HFIndexBase):
+   """
+   A wrapper around an instance of :class:`~datasets.Datasets`. If ``index_path`` is set to ``None``,
+   we load the pre-computed index available with the :class:`~datasets.arrow_dataset.Dataset`, otherwise, we load the index from the indicated path on disk.
+
+   Args:
+       vector_size (:obj:`int`): the dimension of the passages embeddings used by the index
+       dataset_name (:obj:`str`, optional, defaults to ``wiki_dpr``):
+           A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids with ``datasets.list_datasets()``).
+       dataset_split (:obj:`str`, optional, defaults to ``train``)
+           Which split of the ``dataset`` to load.
+       index_name (:obj:`str`, optional, defaults to ``train``)
+           The index_name of the index associated with the ``dataset``. The index loaded from ``index_path`` will be saved under this name.
+       index_path (:obj:`str`, optional, defaults to ``None``)
+           The path to the serialized faiss index on disk.
+       use_dummy_dataset (:obj:`bool`, optional, defaults to ``False``): If True, use the dummy configuration of the dataset for tests.
+   """
+
+   def __init__(
+       self,
+       vector_size: int,
+       dataset_name: str = "wiki_dpr",
+       dataset_split: str = "train",
+       index_name: Optional[str] = None,
+       index_path: Optional[str] = None,
+       use_dummy_dataset=False,
+   ):
+       if int(index_path is None) + int(index_name is None) != 1:
+           raise ValueError("Please provide `index_name` or `index_path`.")
+       self.dataset_name = dataset_name
+       self.dataset_split = dataset_split
+       self.index_name = index_name
+       self.index_path = index_path
+       self.use_dummy_dataset = use_dummy_dataset
+       logger.info("Loading passages from {}".format(self.dataset_name))
+       dataset = load_dataset(
+           self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset
+       )
+       super().__init__(vector_size, dataset, index_initialized=False)
+
+   def init_index(self):
+       if self.index_path is not None:
+           logger.info("Loading index from {}".format(self.index_path))
+           self.dataset.load_faiss_index("embeddings", file=self.index_path)
+       else:
+           logger.info("Loading index from {}".format(self.dataset_name + " with index name " + self.index_name))
+           self.dataset = load_dataset(
+               self.dataset_name,
+               with_embeddings=True,
+               with_index=True,
+               split=self.dataset_split,
+               index_name=self.index_name,
+               dummy=self.use_dummy_dataset,
+           )
+           self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
+       self._index_initialized = True
+
+
+class CustomHFIndex(HFIndexBase):
+   """
+   A wrapper around an instance of :class:`~datasets.Datasets`.
+   The dataset and the index are both loaded from the indicated paths on disk.
+
+   Args:
+       vector_size (:obj:`int`): the dimension of the passages embeddings used by the index
+       dataset_path (:obj:`str`):
+           The path to the serialized dataset on disk.
+           The dataset should have 3 columns: title (str), text (str) and embeddings (arrays of dimension vector_size)
+       index_path (:obj:`str`)
+           The path to the serialized faiss index on disk.
+   """
+
+   def __init__(self, vector_size: int, dataset, index_path=None):
+       super().__init__(vector_size, dataset, index_initialized=index_path is None)
+       self.index_path = index_path
+
+   @classmethod
+   def load_from_disk(cls, vector_size, dataset_path, index_path):
+       logger.info("Loading passages from {}".format(dataset_path))
+       if dataset_path is None or index_path is None:
+           raise ValueError(
+               "Please provide ``dataset_path`` and ``index_path`` after calling ``dataset.save_to_disk(dataset_path)`` "
+               "and ``dataset.get_index('embeddings').save(index_path)``."
+           )
+       dataset = load_from_disk(dataset_path)
+       return cls(vector_size=vector_size, dataset=dataset, index_path=index_path)
+
+   def init_index(self):
+       if not self.is_initialized():
+           logger.info("Loading index from {}".format(self.index_path))
+           self.dataset.load_faiss_index("embeddings", file=self.index_path)
+           self._index_initialized = True
+
+
class RagRetriever:
    """
    Retriever used to get documents from vector queries.
@@ -271,34 +334,46 @@ class RagRetriever:
    Args:
        config (:class:`~transformers.RagConfig`):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
+           You can load your own custom dataset with ``config.index_name="custom"`` or use a canonical one (default) from the datasets library
+           with ``config.index_name="wiki_dpr"`` for example.
        question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer that was used to tokenize the question.
            It is used to decode the question and then use the generator_tokenizer.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
+       index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
+           If specified, use this index instead of the one built using the configuration
+
+   Examples::
+
+       >>> # To load the default "wiki_dpr" dataset with 21M passages from wikipedia (index name is 'compressed' or 'exact')
+       >>> from transformers import RagRetriever
+       >>> retriever = RagRetriever.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', dataset="wiki_dpr", index_name='compressed')
+
+       >>> # To load your own indexed dataset built with the datasets library. More info on how to build the indexed dataset in examples/rag/use_own_knowledge_dataset.py
+       >>> from transformers import RagRetriever
+       >>> dataset = ...  # dataset must be a datasets.Datasets object with columns "title", "text" and "embeddings", and it must have a faiss index
+       >>> retriever = RagRetriever.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', indexed_dataset=dataset)
+
+       >>> # To load your own indexed dataset built with the datasets library that was saved on disk. More info in examples/rag/use_own_knowledge_dataset.py
+       >>> from transformers import RagRetriever
+       >>> dataset_path = "path/to/my/dataset"  # dataset saved via `dataset.save_to_disk(...)`
+       >>> index_path = "path/to/my/index.faiss"  # faiss index saved via `dataset.get_index("embeddings").save(...)`
+       >>> retriever = RagRetriever.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', index_name='custom', passages_path=dataset_path, index_path=index_path)
+
+       >>> # To load the legacy index built originally for Rag's paper
+       >>> from transformers import RagRetriever
+       >>> retriever = RagRetriever.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', index_name='legacy')
+
    """

    _init_retrieval = True

-   def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
+   def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
        requires_datasets(self)
        requires_faiss(self)
        super().__init__()
-       self.index = (
-           LegacyIndex(
-               config.retrieval_vector_size,
-               config.index_path or LEGACY_INDEX_PATH,
-           )
-           if config.index_name == "legacy"
-           else HFIndex(
-               config.dataset,
-               config.dataset_split,
-               config.index_name,
-               config.retrieval_vector_size,
-               config.index_path,
-               config.use_dummy_dataset,
-           )
-       )
+       self.index = index or self._build_index(config)
        self.generator_tokenizer = generator_tokenizer
        self.question_encoder_tokenizer = question_encoder_tokenizer
@@ -309,19 +384,62 @@ class RagRetriever:
        if self._init_retrieval:
            self.init_retrieval()

+   @staticmethod
+   def _build_index(config):
+       if config.index_name == "legacy":
+           return LegacyIndex(
+               config.retrieval_vector_size,
+               config.index_path or LEGACY_INDEX_PATH,
+           )
+       elif config.index_name == "custom":
+           return CustomHFIndex.load_from_disk(
+               vector_size=config.retrieval_vector_size,
+               dataset_path=config.passages_path,
+               index_path=config.index_path,
+           )
+       else:
+           return CanonicalHFIndex(
+               vector_size=config.retrieval_vector_size,
+               dataset_name=config.dataset,
+               dataset_split=config.dataset_split,
+               index_name=config.index_name,
+               index_path=config.index_path,
+               use_dummy_dataset=config.use_dummy_dataset,
+           )
+
    @classmethod
-   def from_pretrained(cls, retriever_name_or_path, **kwargs):
+   def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        requires_datasets(cls)
        requires_faiss(cls)
-       config = RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
+       config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
        rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
        question_encoder_tokenizer = rag_tokenizer.question_encoder
        generator_tokenizer = rag_tokenizer.generator
+       if indexed_dataset is not None:
+           config.index_name = "custom"
+           index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
+       else:
+           index = cls._build_index(config)
        return cls(
-           config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
+           config,
+           question_encoder_tokenizer=question_encoder_tokenizer,
+           generator_tokenizer=generator_tokenizer,
+           index=index,
        )

    def save_pretrained(self, save_directory):
+       if isinstance(self.index, CustomHFIndex):
+           if self.config.index_path is None:
+               index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
+               self.index.dataset.get_index("embeddings").save(index_path)
+               self.config.index_path = index_path
+           if self.config.passages_path is None:
+               passages_path = os.path.join(save_directory, "hf_dataset")
+               # datasets don't support save_to_disk with indexes right now
+               faiss_index = self.index.dataset._indexes.pop("embeddings")
+               self.index.dataset.save_to_disk(passages_path)
+               self.index.dataset._indexes["embeddings"] = faiss_index
+               self.config.passages_path = passages_path
        self.config.save_pretrained(save_directory)
        rag_tokenizer = RagTokenizer(
            question_encoder=self.question_encoder_tokenizer,
@@ -26,43 +26,64 @@ from .utils import logging

logger = logging.get_logger(__name__)

-VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}

CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
-   }
+       "facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+   },
+   "tokenizer_file": {
+       "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
+       "facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
+   },
}
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
-   }
+       "facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+   },
+   "tokenizer_file": {
+       "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
+       "facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
+   },
}
READER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
-   }
+       "facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+   },
+   "tokenizer_file": {
+       "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
+       "facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
+   },
}

CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/dpr-ctx_encoder-single-nq-base": 512,
+   "facebook/dpr-ctx_encoder-multiset-base": 512,
}
QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/dpr-question_encoder-single-nq-base": 512,
+   "facebook/dpr-question_encoder-multiset-base": 512,
}
READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/dpr-reader-single-nq-base": 512,
+   "facebook/dpr-reader-multiset-base": 512,
}


CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
    "facebook/dpr-ctx_encoder-single-nq-base": {"do_lower_case": True},
+   "facebook/dpr-ctx_encoder-multiset-base": {"do_lower_case": True},
}
QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
    "facebook/dpr-question_encoder-single-nq-base": {"do_lower_case": True},
+   "facebook/dpr-question_encoder-multiset-base": {"do_lower_case": True},
}
READER_PRETRAINED_INIT_CONFIGURATION = {
    "facebook/dpr-reader-single-nq-base": {"do_lower_case": True},
+   "facebook/dpr-reader-multiset-base": {"do_lower_case": True},
}
@@ -32,47 +32,59 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+       "facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    },
    "tokenizer_file": {
        "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
+       "facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
    },
}
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+       "facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    },
    "tokenizer_file": {
        "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
+       "facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
    },
}
READER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+       "facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    },
    "tokenizer_file": {
        "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
+       "facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
    },
}

CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/dpr-ctx_encoder-single-nq-base": 512,
+   "facebook/dpr-ctx_encoder-multiset-base": 512,
}
QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/dpr-question_encoder-single-nq-base": 512,
+   "facebook/dpr-question_encoder-multiset-base": 512,
}
READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/dpr-reader-single-nq-base": 512,
+   "facebook/dpr-reader-multiset-base": 512,
}


CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
    "facebook/dpr-ctx_encoder-single-nq-base": {"do_lower_case": True},
+   "facebook/dpr-ctx_encoder-multiset-base": {"do_lower_case": True},
}
QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
    "facebook/dpr-question_encoder-single-nq-base": {"do_lower_case": True},
+   "facebook/dpr-question_encoder-multiset-base": {"do_lower_case": True},
}
READER_PRETRAINED_INIT_CONFIGURATION = {
    "facebook/dpr-reader-single-nq-base": {"do_lower_case": True},
+   "facebook/dpr-reader-multiset-base": {"do_lower_case": True},
}
@@ -13,7 +13,7 @@ import faiss
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig
-from transformers.retrieval_rag import RagRetriever
+from transformers.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.testing_utils import (
    require_datasets,
    require_faiss,

@@ -103,7 +103,7 @@ class RagRetrieverTest(TestCase):
    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

-   def get_dummy_hf_index_retriever(self):
+   def get_dummy_dataset(self):
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],

@@ -113,6 +113,10 @@ class RagRetrieverTest(TestCase):
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
+       return dataset
+
+   def get_dummy_canonical_hf_index_retriever(self):
+       dataset = self.get_dummy_dataset()
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),

@@ -127,6 +131,35 @@ class RagRetrieverTest(TestCase):
        )
        return retriever

+   def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
+       dataset = self.get_dummy_dataset()
+       config = RagConfig(
+           retrieval_vector_size=self.retrieval_vector_size,
+           question_encoder=DPRConfig().to_dict(),
+           generator=BartConfig().to_dict(),
+           index_name="custom",
+       )
+       if from_disk:
+           config.passages_path = os.path.join(self.tmpdirname, "dataset")
+           config.index_path = os.path.join(self.tmpdirname, "index.faiss")
+           dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
+           dataset.drop_index("embeddings")
+           dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
+           del dataset
+           retriever = RagRetriever(
+               config,
+               question_encoder_tokenizer=self.get_dpr_tokenizer(),
+               generator_tokenizer=self.get_bart_tokenizer(),
+           )
+       else:
+           retriever = RagRetriever(
+               config,
+               question_encoder_tokenizer=self.get_dpr_tokenizer(),
+               generator_tokenizer=self.get_bart_tokenizer(),
+               index=CustomHFIndex(config.retrieval_vector_size, dataset),
+           )
+       return retriever
+
    def get_dummy_legacy_index_retriever(self):
        dataset = Dataset.from_dict(
            {

@@ -152,16 +185,15 @@ class RagRetrieverTest(TestCase):
            generator=BartConfig().to_dict(),
            index_name="legacy",
            index_path=self.tmpdirname,
-           passages_path=self.tmpdirname,
        )
        retriever = RagRetriever(
            config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
        )
        return retriever

-   def test_hf_index_retriever_retrieve(self):
+   def test_canonical_hf_index_retriever_retrieve(self):
        n_docs = 1
-       retriever = self.get_dummy_hf_index_retriever()
+       retriever = self.get_dummy_canonical_hf_index_retriever()
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )

@@ -174,10 +206,73 @@ class RagRetrieverTest(TestCase):
        self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
        self.assertListEqual(doc_ids.tolist(), [[1], [0]])

-   def test_save_and_from_pretrained(self):
-       retriever = self.get_dummy_hf_index_retriever()
+   def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
+       retriever = self.get_dummy_canonical_hf_index_retriever()
        with tempfile.TemporaryDirectory() as tmp_dirname:
+           with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
+               mock_load_dataset.return_value = self.get_dummy_dataset()
                retriever.save_pretrained(tmp_dirname)
                retriever = RagRetriever.from_pretrained(tmp_dirname)
                self.assertIsInstance(retriever, RagRetriever)
                hidden_states = np.array(
                    [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
                )
                out = retriever.retrieve(hidden_states, n_docs=1)
                self.assertTrue(out is not None)
+
+   def test_custom_hf_index_retriever_retrieve(self):
+       n_docs = 1
+       retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
+       hidden_states = np.array(
+           [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
+       )
+       retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
+       self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
+       self.assertEqual(len(doc_dicts), 2)
+       self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
+       self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
+       self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
+       self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
+       self.assertListEqual(doc_ids.tolist(), [[1], [0]])
+
+   def test_custom_hf_index_retriever_save_and_from_pretrained(self):
+       retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
+       with tempfile.TemporaryDirectory() as tmp_dirname:
+           retriever.save_pretrained(tmp_dirname)
+           retriever = RagRetriever.from_pretrained(tmp_dirname)
+           self.assertIsInstance(retriever, RagRetriever)
+           hidden_states = np.array(
+               [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
+           )
+           out = retriever.retrieve(hidden_states, n_docs=1)
+           self.assertTrue(out is not None)
+
+   def test_custom_hf_index_retriever_retrieve_from_disk(self):
+       n_docs = 1
+       retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
+       hidden_states = np.array(
+           [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
+       )
+       retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
+       self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
+       self.assertEqual(len(doc_dicts), 2)
+       self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
+       self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
+       self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
+       self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
+       self.assertListEqual(doc_ids.tolist(), [[1], [0]])
+
+   def test_custom_hf_index_retriever_save_and_from_pretrained_from_disk(self):
+       retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
+       with tempfile.TemporaryDirectory() as tmp_dirname:
+           retriever.save_pretrained(tmp_dirname)
+           retriever = RagRetriever.from_pretrained(tmp_dirname)
+           self.assertIsInstance(retriever, RagRetriever)
+           hidden_states = np.array(
+               [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
+           )
+           out = retriever.retrieve(hidden_states, n_docs=1)
+           self.assertTrue(out is not None)
+
    def test_legacy_index_retriever_retrieve(self):
        n_docs = 1

@@ -194,6 +289,18 @@ class RagRetrieverTest(TestCase):
        self.assertEqual(doc_dicts[1]["text"][0], "foo")  # max inner product is reached with first doc
        self.assertListEqual(doc_ids.tolist(), [[1], [0]])

+   def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
+       retriever = self.get_dummy_legacy_index_retriever()
+       with tempfile.TemporaryDirectory() as tmp_dirname:
+           retriever.save_pretrained(tmp_dirname)
+           retriever = RagRetriever.from_pretrained(tmp_dirname)
+           self.assertIsInstance(retriever, RagRetriever)
+           hidden_states = np.array(
+               [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
+           )
+           out = retriever.retrieve(hidden_states, n_docs=1)
+           self.assertTrue(out is not None)
+
    @require_torch
    @require_tokenizers
    @require_sentencepiece

@@ -201,7 +308,7 @@ class RagRetrieverTest(TestCase):
        import torch

        n_docs = 1
-       retriever = self.get_dummy_hf_index_retriever()
+       retriever = self.get_dummy_canonical_hf_index_retriever()
        question_input_ids = [[5, 7], [10, 11]]
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )