diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index c0f14663f..0a8e6e845 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -3,11 +3,13 @@ import itertools
 import types
 from typing import Dict
 
-from ..utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
 from .base import Pipeline, build_pipeline_init_args
 
 
 if is_torch_available():
+    import torch
+
     from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
     from .pt_utils import KeyDataset
 
@@ -380,13 +382,44 @@ class TextGenerationPipeline(Pipeline):
         if "generation_config" not in generate_kwargs:
             generate_kwargs["generation_config"] = self.generation_config
 
-        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
+        output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
+
+        if isinstance(output, ModelOutput):
+            generated_sequence = output.sequences
+            other_outputs = {k: v for k, v in output.items() if k != "sequences"}
+            out_b = generated_sequence.shape[0]
+
+            if self.framework == "pt":
+                for key, value in other_outputs.items():
+                    if isinstance(value, torch.Tensor) and value.shape[0] == out_b:
+                        other_outputs[key] = value.reshape(in_b, out_b // in_b, *value.shape[1:])
+                    if isinstance(value, tuple) and len(value[0]) == out_b:
+                        value = torch.stack(value).swapaxes(0, 1)
+                        other_outputs[key] = value
+            elif self.framework == "tf":
+                for key, value in other_outputs.items():
+                    if isinstance(value, tf.Tensor) and value.shape[0] == out_b:
+                        other_outputs[key] = tf.reshape(value, (in_b, out_b // in_b, *value.shape[1:]))
+                    if isinstance(value, tuple) and len(value[0]) == out_b:
+                        value = tf.experimental.numpy.swapaxes(tf.stack(value), 0, 1)
+                        other_outputs[key] = value
+        else:
+            generated_sequence = output
+            other_outputs = {}
+
         out_b = generated_sequence.shape[0]
         if self.framework == "pt":
             generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
         elif self.framework == "tf":
             generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
-        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
+
+        model_outputs = {
+            "generated_sequence": generated_sequence,
+            "input_ids": input_ids,
+            "prompt_text": prompt_text,
+            "additional_outputs": other_outputs,
+        }
+        return model_outputs
 
     def postprocess(
         self,
@@ -400,7 +433,19 @@
         prompt_text = model_outputs["prompt_text"]
         generated_sequence = generated_sequence.numpy().tolist()
         records = []
-        for sequence in generated_sequence:
+        other_outputs = model_outputs.get("additional_outputs", {})
+        splitted_keys = {}
+        if other_outputs:
+            if self.framework == "pt":
+                for k, v in other_outputs.items():
+                    if isinstance(v, torch.Tensor) and v.shape[0] == len(generated_sequence):
+                        splitted_keys[k] = v.numpy().tolist()
+            elif self.framework == "tf":
+                for k, v in other_outputs.items():
+                    if isinstance(v, tf.Tensor) and v.shape[0] == len(generated_sequence):
+                        splitted_keys[k] = v.numpy().tolist()
+
+        for idx, sequence in enumerate(generated_sequence):
             if return_type == ReturnType.TENSORS:
                 record = {"generated_token_ids": sequence}
             elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
@@ -444,6 +489,8 @@ class TextGenerationPipeline(Pipeline):
                         # When we're not starting from a prefill, the output is a new assistant message
                         all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
                 record = {"generated_text": all_text}
+                for key, values in splitted_keys.items():
+                    record[key] = values[idx]
             records.append(record)
 
         return records
diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index 7504ae009..5c5d3de17 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -653,6 +653,32 @@ class TextGenerationPipelineTests(unittest.TestCase):
             _ = text_generator(prompt, max_length=10)
             self.assertNotIn(logger_msg, cl.out)
 
+    @require_torch
+    def test_return_dict_in_generate(self):
+        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=16)
+        out = text_generator(
+            ["This is great !", "Something else"], return_dict_in_generate=True, output_logits=True, output_scores=True
+        )
+        self.assertEqual(
+            out,
+            [
+                [
+                    {
+                        "generated_text": ANY(str),
+                        "logits": ANY(list),
+                        "scores": ANY(list),
+                    },
+                ],
+                [
+                    {
+                        "generated_text": ANY(str),
+                        "logits": ANY(list),
+                        "scores": ANY(list),
+                    },
+                ],
+            ],
+        )
+
     @require_torch
     def test_pipeline_assisted_generation(self):
         """Tests that we can run assisted generation in the pipeline"""
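
For context, a minimal usage sketch of what this diff enables, assuming a transformers build with the patch applied; the prompt and `max_new_tokens=4` are arbitrary illustration values, and the model is the same tiny test checkpoint used in the test above:

```python
from transformers import pipeline

# Without this patch, return_dict_in_generate=True makes generate() return a
# ModelOutput, which the pipeline's reshape logic cannot handle (it expects a
# plain tensor of sequences). With the patch, the extra generate() outputs are
# attached to each returned record alongside "generated_text".
generator = pipeline(
    "text-generation",
    model="hf-internal-testing/tiny-random-gpt2",
    max_new_tokens=4,
)
out = generator("Hello", return_dict_in_generate=True, output_scores=True)
print(out[0].keys())  # expected: dict_keys(['generated_text', 'scores'])
```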
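
The tuple branch in `_forward` exists because `generate()` returns per-step outputs such as `scores` and `logits` as a tuple with one `(batch, vocab)` tensor per generated token. A standalone PyTorch sketch of that conversion, with shapes chosen purely for illustration:

```python
import torch

# generate() yields one (batch_size, vocab_size) tensor per decoding step.
num_steps, batch_size, vocab_size = 5, 2, 10
scores = tuple(torch.randn(batch_size, vocab_size) for _ in range(num_steps))

# Stack to (steps, batch, vocab), then swap axes so the leading dimension
# matches `sequences`: (batch, steps, vocab).
batched = torch.stack(scores).swapaxes(0, 1)
print(batched.shape)  # torch.Size([2, 5, 10])
```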