Output dicts support in text generation pipeline (#35092)

* Support the generate argument return_dict_in_generate=True instead of raising an error

* fix: call test with return_dict_in_generate=True

* fix: Only import torch if it is present

* update: Encapsulate output_dict changes

* fix: added back original comments

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Jonas Rohw 2025-01-29 15:44:46 +01:00 committed by GitHub
parent cf90404807
commit 23d782ead2
2 changed files with 76 additions and 4 deletions
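
For context, a minimal usage sketch of the behaviour this commit enables (model name borrowed from the new test below; the exact keys in each record depend on which generate outputs are requested):

    from transformers import pipeline

    generator = pipeline(
        "text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=16
    )

    # Before this commit, passing return_dict_in_generate=True raised an error.
    # Now the extra generate outputs are merged into each returned record.
    out = generator(
        "This is great !", return_dict_in_generate=True, output_scores=True, output_logits=True
    )
    print(sorted(out[0].keys()))  # ['generated_text', 'logits', 'scores']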

src/transformers/pipelines/text_generation.py

@@ -3,11 +3,13 @@ import itertools
 import types
 from typing import Dict
 
-from ..utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
 from .base import Pipeline, build_pipeline_init_args
 
 
 if is_torch_available():
+    import torch
+
     from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
     from .pt_utils import KeyDataset
@@ -380,13 +382,44 @@ class TextGenerationPipeline(Pipeline):
         if "generation_config" not in generate_kwargs:
             generate_kwargs["generation_config"] = self.generation_config
 
-        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
+        output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
+        if isinstance(output, ModelOutput):
+            generated_sequence = output.sequences
+            other_outputs = {k: v for k, v in output.items() if k != "sequences"}
+            out_b = generated_sequence.shape[0]
+
+            if self.framework == "pt":
+                for key, value in other_outputs.items():
+                    if isinstance(value, torch.Tensor) and value.shape[0] == out_b:
+                        other_outputs[key] = value.reshape(in_b, out_b // in_b, *value.shape[1:])
+                    if isinstance(value, tuple) and len(value[0]) == out_b:
+                        value = torch.stack(value).swapaxes(0, 1)
+                        other_outputs[key] = value
+            elif self.framework == "tf":
+                for key, value in other_outputs.items():
+                    if isinstance(value, tf.Tensor) and value.shape[0] == out_b:
+                        other_outputs[key] = tf.reshape(value, (in_b, out_b // in_b, *value.shape[1:]))
+                    if isinstance(value, tuple) and len(value[0]) == out_b:
+                        value = tf.stack(value).swapaxes(0, 1)
+                        other_outputs[key] = value
+        else:
+            generated_sequence = output
+            other_outputs = {}
+
         out_b = generated_sequence.shape[0]
         if self.framework == "pt":
             generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
         elif self.framework == "tf":
             generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
-        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
+
+        model_outputs = {
+            "generated_sequence": generated_sequence,
+            "input_ids": input_ids,
+            "prompt_text": prompt_text,
+        }
+        model_outputs.update(other_outputs)
+        return model_outputs
 
     def postprocess(
         self,
@@ -400,7 +433,19 @@ class TextGenerationPipeline(Pipeline):
         prompt_text = model_outputs["prompt_text"]
         generated_sequence = generated_sequence.numpy().tolist()
         records = []
-        for sequence in generated_sequence:
+        other_outputs = model_outputs.get("additional_outputs", {})
+        splitted_keys = {}
+        if other_outputs:
+            if self.framework == "pt":
+                for k, v in other_outputs.items():
+                    if isinstance(v, torch.Tensor) and v.shape[0] == len(generated_sequence):
+                        splitted_keys[k] = v.numpy().tolist()
+            elif self.framework == "tf":
+                for k, v in other_outputs.items():
+                    if isinstance(v, tf.Tensor) and v.shape[0] == len(generated_sequence):
+                        splitted_keys[k] = v.numpy().tolist()
+
+        for idx, sequence in enumerate(generated_sequence):
             if return_type == ReturnType.TENSORS:
                 record = {"generated_token_ids": sequence}
             elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
@@ -444,6 +489,8 @@ class TextGenerationPipeline(Pipeline):
                 # When we're not starting from a prefill, the output is a new assistant message
                 all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
             record = {"generated_text": all_text}
+            for key, values in splitted_keys.items():
+                record[key] = values[idx]
             records.append(record)
 
         return records
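
A note on the reshapes in _forward above: generate returns a flat batch of in_b * num_return_sequences rows, and the pipeline regroups every matching extra output per input prompt. A standalone sketch of that regrouping, with illustrative shapes:

    import torch

    in_b = 2                    # number of input prompts
    value = torch.randn(6, 50)  # out_b = 6 rows, e.g. num_return_sequences = 3
    out_b = value.shape[0]

    # Mirrors: other_outputs[key] = value.reshape(in_b, out_b // in_b, *value.shape[1:])
    grouped = value.reshape(in_b, out_b // in_b, *value.shape[1:])
    print(grouped.shape)        # torch.Size([2, 3, 50])

    # Per-step outputs such as scores arrive as a tuple with one (out_b, vocab)
    # tensor per generated token; stacking and swapping axes yields (out_b, steps, vocab).
    steps = 4
    scores = tuple(torch.randn(out_b, 50) for _ in range(steps))
    stacked = torch.stack(scores).swapaxes(0, 1)
    print(stacked.shape)        # torch.Size([6, 4, 50])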

tests/pipelines/test_pipelines_text_generation.py

@@ -653,6 +653,31 @@ class TextGenerationPipelineTests(unittest.TestCase):
             _ = text_generator(prompt, max_length=10)
             self.assertNotIn(logger_msg, cl.out)
 
+    def test_return_dict_in_generate(self):
+        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=16)
+        out = text_generator(
+            ["This is great !", "Something else"], return_dict_in_generate=True, output_logits=True, output_scores=True
+        )
+        self.assertEqual(
+            out,
+            [
+                [
+                    {
+                        "generated_text": ANY(str),
+                        "logits": ANY(list),
+                        "scores": ANY(list),
+                    },
+                ],
+                [
+                    {
+                        "generated_text": ANY(str),
+                        "logits": ANY(list),
+                        "scores": ANY(list),
+                    },
+                ],
+            ],
+        )
+
     @require_torch
     def test_pipeline_assisted_generation(self):
         """Tests that we can run assisted generation in the pipeline"""