diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 27ad1d462..89048c5b6 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -1624,7 +1624,17 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler): elif "data" in kwargs: inputs = kwargs["data"] elif "question" in kwargs and "context" in kwargs: - inputs = [{"question": kwargs["question"], "context": kwargs["context"]}] + if isinstance(kwargs["question"], list) and isinstance(kwargs["context"], str): + inputs = [{"question": Q, "context": kwargs["context"]} for Q in kwargs["question"]] + elif isinstance(kwargs["question"], list) and isinstance(kwargs["context"], list): + if len(kwargs["question"]) != len(kwargs["context"]): + raise ValueError("Questions and contexts don't have the same lengths") + + inputs = [{"question": Q, "context": C} for Q, C in zip(kwargs["question"], kwargs["context"])] + elif isinstance(kwargs["question"], str) and isinstance(kwargs["context"], str): + inputs = [{"question": kwargs["question"], "context": kwargs["context"]}] + else: + raise ValueError("Arguments can't be understood") else: raise ValueError("Unknown arguments {}".format(kwargs)) diff --git a/tests/test_pipelines_question_answering.py b/tests/test_pipelines_question_answering.py index 9b25c5734..e977f655b 100644 --- a/tests/test_pipelines_question_answering.py +++ b/tests/test_pipelines_question_answering.py @@ -23,6 +23,17 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): "question": "In what field is HuggingFace working ?", "context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.", }, + { + "question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"], + "context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.", + }, + { + "question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"], + "context": [ + "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.", + "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.", + ], + }, ] def _test_pipeline(self, nlp: Pipeline): @@ -80,6 +91,11 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): self.assertEqual(len(normalized), 1) self.assertEqual({type(el) for el in normalized}, {SquadExample}) + normalized = qa(question=[Q, Q], context=C) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 2) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + normalized = qa({"question": Q, "context": C}) self.assertEqual(type(normalized), list) self.assertEqual(len(normalized), 1) @@ -159,6 +175,26 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): with self.assertRaises(ValueError): qa([{"question": Q, "context": C}, {"question": Q, "context": ""}]) + with self.assertRaises(ValueError): + qa(question={"This": "Is weird"}, context="This is a context") + + with self.assertRaises(ValueError): + qa(question=[Q, Q], context=[C, C, C]) + + with self.assertRaises(ValueError): + qa(question=[Q, Q, Q], context=[C, C]) + + def test_argument_handler_old_format(self): + qa = QuestionAnsweringArgumentHandler() + + Q = "Where was HuggingFace founded ?" + C = "HuggingFace was founded in Paris" + # Backward compatibility for this + normalized = qa(question=[Q, Q], context=[C, C]) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 2) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + def test_argument_handler_error_handling_odd(self): qa = QuestionAnsweringArgumentHandler() with self.assertRaises(ValueError):