mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Output dicts support in text generation pipeline (#35092)
* Support for generate_argument: return_dict_in_generate=True, instead of returning a error * fix: call test with return_dict_in_generate=True * fix: Only import torch if it is present * update: Encapsulate output_dict changes * fix: added back original comments --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent
cf90404807
commit
23d782ead2
2 changed files with 76 additions and 4 deletions
|
|
@ -3,11 +3,13 @@ import itertools
|
|||
import types
|
||||
from typing import Dict
|
||||
|
||||
from ..utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||
from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
|
||||
from .base import Pipeline, build_pipeline_init_args
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
from .pt_utils import KeyDataset
|
||||
|
||||
|
|
@ -380,13 +382,44 @@ class TextGenerationPipeline(Pipeline):
|
|||
if "generation_config" not in generate_kwargs:
|
||||
generate_kwargs["generation_config"] = self.generation_config
|
||||
|
||||
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
||||
output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
||||
|
||||
if isinstance(output, ModelOutput):
|
||||
generated_sequence = output.sequences
|
||||
other_outputs = {k: v for k, v in output.items() if k != "sequences"}
|
||||
out_b = generated_sequence.shape[0]
|
||||
|
||||
if self.framework == "pt":
|
||||
for key, value in other_outputs.items():
|
||||
if isinstance(value, torch.Tensor) and value.shape[0] == out_b:
|
||||
other_outputs[key] = value.reshape(in_b, out_b // in_b, *value.shape[1:])
|
||||
if isinstance(value, tuple) and len(value[0]) == out_b:
|
||||
value = torch.stack(value).swapaxes(0, 1)
|
||||
other_outputs[key] = value
|
||||
elif self.framework == "tf":
|
||||
for key, value in other_outputs.items():
|
||||
if isinstance(value, tf.Tensor) and value.shape[0] == out_b:
|
||||
other_outputs[key] = tf.reshape(value, (in_b, out_b // in_b, *value.shape[1:]))
|
||||
if isinstance(value, tuple) and len(value[0]) == out_b:
|
||||
value = tf.stack(value).swapaxes(0, 1)
|
||||
other_outputs[key] = value
|
||||
else:
|
||||
generated_sequence = output
|
||||
other_outputs = {}
|
||||
|
||||
out_b = generated_sequence.shape[0]
|
||||
if self.framework == "pt":
|
||||
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||
elif self.framework == "tf":
|
||||
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
||||
|
||||
model_outputs = {
|
||||
"generated_sequence": generated_sequence,
|
||||
"input_ids": input_ids,
|
||||
"prompt_text": prompt_text,
|
||||
}
|
||||
model_outputs.update(other_outputs)
|
||||
return model_outputs
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
|
|
@ -400,7 +433,19 @@ class TextGenerationPipeline(Pipeline):
|
|||
prompt_text = model_outputs["prompt_text"]
|
||||
generated_sequence = generated_sequence.numpy().tolist()
|
||||
records = []
|
||||
for sequence in generated_sequence:
|
||||
other_outputs = model_outputs.get("additional_outputs", {})
|
||||
splitted_keys = {}
|
||||
if other_outputs:
|
||||
if self.framework == "pt":
|
||||
for k, v in other_outputs.items():
|
||||
if isinstance(v, torch.Tensor) and v.shape[0] == len(generated_sequence):
|
||||
splitted_keys[k] = v.numpy().tolist()
|
||||
elif self.framework == "tf":
|
||||
for k, v in other_outputs.items():
|
||||
if isinstance(v, tf.Tensor) and v.shape[0] == len(generated_sequence):
|
||||
splitted_keys[k] = v.numpy().tolist()
|
||||
|
||||
for idx, sequence in enumerate(generated_sequence):
|
||||
if return_type == ReturnType.TENSORS:
|
||||
record = {"generated_token_ids": sequence}
|
||||
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
|
||||
|
|
@ -444,6 +489,8 @@ class TextGenerationPipeline(Pipeline):
|
|||
# When we're not starting from a prefill, the output is a new assistant message
|
||||
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
|
||||
record = {"generated_text": all_text}
|
||||
for key, values in splitted_keys.items():
|
||||
record[key] = values[idx]
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
|
|
|||
|
|
@ -653,6 +653,31 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||
_ = text_generator(prompt, max_length=10)
|
||||
self.assertNotIn(logger_msg, cl.out)
|
||||
|
||||
def test_return_dict_in_generate(self):
|
||||
text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=16)
|
||||
out = text_generator(
|
||||
["This is great !", "Something else"], return_dict_in_generate=True, output_logits=True, output_scores=True
|
||||
)
|
||||
self.assertEqual(
|
||||
out,
|
||||
[
|
||||
[
|
||||
{
|
||||
"generated_text": ANY(str),
|
||||
"logits": ANY(list),
|
||||
"scores": ANY(list),
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"generated_text": ANY(str),
|
||||
"logits": ANY(list),
|
||||
"scores": ANY(list),
|
||||
},
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_assisted_generation(self):
|
||||
"""Tests that we can run assisted generation in the pipeline"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue