From 995a958dd18d4326e608efc3bfc4005acfef8e56 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Mon, 7 Sep 2020 02:03:45 -0500 Subject: [PATCH] feat: allow prefix for any generative model (#5885) * feat: allow padding_text for any generative model * docs(pipelines.py): correct typo * Update src/transformers/pipelines.py Co-authored-by: Sam Shleifer * feat: rename padding_text to prefix * fix: cannot tokenize empty text * fix: pass prefix arg to pipeline * test: add prefix to text-generetation pipeline * style: fix style * style: clean code and variable name more explicit * set arg docstring to optional Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sam Shleifer Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- examples/text-generation/run_generation.py | 14 ++++--- src/transformers/pipelines.py | 46 +++++++++++++--------- tests/test_pipelines.py | 2 + 3 files changed, 38 insertions(+), 24 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 1b4b6f1e5..ce4348325 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -61,7 +61,7 @@ MODEL_CLASSES = { # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia # in https://github.com/rusiaaman/XLNet-gen#methodology # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e -PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family +PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western Siberia, @@ -122,12 +122,14 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text): def prepare_xlnet_input(args, _, tokenizer, prompt_text): - prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text + prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX + prompt_text = prefix + prompt_text return prompt_text def prepare_transfoxl_input(args, _, tokenizer, prompt_text): - prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text + prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX + prompt_text = prefix + prompt_text return prompt_text @@ -182,7 +184,8 @@ def main(): parser.add_argument("--k", type=int, default=0) parser.add_argument("--p", type=float, default=0.9) - parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.") + parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") + parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.") parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") @@ -241,7 +244,8 @@ def main(): preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs ) else: - encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt") + prefix = args.prefix if args.prefix else args.padding_text + encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt") encoded_prompt = encoded_prompt.to(args.device) if encoded_prompt.size()[-1] == 0: diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index bbaa89c1b..c280d1c02 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -752,11 +752,11 @@ class TextGenerationPipeline(Pipeline): `huggingface.co/models `__. """ - # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia + # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia # in https://github.com/rusiaaman/XLNet-gen#methodology # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e - PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family + XL_PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western Siberia, @@ -765,7 +765,7 @@ class TextGenerationPipeline(Pipeline): father initially slaps him for making such an accusation, Rasputin watches as the man is chased outside and beaten. Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, - with people, even a bishop, begging for his blessing. """ + with people, even a bishop, begging for his blessing. """ ALLOWED_MODELS = [ "XLNetLMHeadModel", @@ -809,7 +809,13 @@ class TextGenerationPipeline(Pipeline): return inputs def __call__( - self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs + self, + *args, + return_tensors=False, + return_text=True, + clean_up_tokenization_spaces=False, + prefix=None, + **generate_kwargs ): """ Complete the prompt(s) given as inputs. @@ -823,6 +829,8 @@ class TextGenerationPipeline(Pipeline): Whether or not to include the decoded texts in the outputs. clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to clean up the potential extra spaces in the text output. + prefix (:obj:`str`, `optional`): + Prefix added to prompt. generate_kwargs: Additional keyword arguments to pass along to the generate method of the model (see the generate method corresponding to your framework `here <./model.html#generative-models>`__). @@ -841,27 +849,27 @@ class TextGenerationPipeline(Pipeline): for prompt_text in text_inputs: # Manage correct placement of the tensors with self.device_placement(): - if self.model.__class__.__name__ in [ + prefix = prefix if prefix is not None else self.model.config.prefix + if prefix is None and self.model.__class__.__name__ in [ "XLNetLMHeadModel", "TransfoXLLMHeadModel", "TFXLNetLMHeadModel", "TFTransfoXLLMHeadModel", ]: - # For XLNet and TransformerXL we had an article to the prompt to give more state to the model. - padding_text = self.PADDING_TEXT + self.tokenizer.eos_token - padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False) - # This impacts max_length and min_length argument that need adjusting. - padding_length = padding["input_ids"].shape[-1] - if "max_length" in generate_kwargs and generate_kwargs["max_length"] is not None: - generate_kwargs["max_length"] += padding_length - if "min_length" in generate_kwargs and generate_kwargs["min_length"] is not None: - generate_kwargs["min_length"] += padding_length + # For XLNet and TransformerXL we add an article to the prompt to give more state to the model. + prefix = self.XL_PREFIX - inputs = self._parse_and_tokenize( - padding_text + prompt_text, padding=False, add_special_tokens=False - ) - else: - inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False) + if prefix: + prefix_inputs = self._parse_and_tokenize(prefix, padding=False, add_special_tokens=False) + # This impacts max_length and min_length argument that need adjusting. + prefix_length = prefix_inputs["input_ids"].shape[-1] + if generate_kwargs.get("max_length", None) is not None: + generate_kwargs["max_length"] += prefix_length + if generate_kwargs.get("min_length", None) is not None: + generate_kwargs["min_length"] += prefix_length + + prefix = prefix or "" + inputs = self._parse_and_tokenize(prefix + prompt_text, padding=False, add_special_tokens=False) # set input_ids to None to allow empty prompt if inputs["input_ids"].shape[-1] == 0: diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index f475e89bd..7551350c4 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -424,12 +424,14 @@ class MonoColumnInputTestCase(unittest.TestCase): for model_name in TEXT_GENERATION_FINETUNED_MODELS: nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="pt") self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}) + self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}, prefix="This is ") @require_tf def test_tf_text_generation(self): for model_name in TEXT_GENERATION_FINETUNED_MODELS: nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="tf") self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}) + self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}, prefix="This is ") @slow @require_torch