Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
Generate: TextIteratorStreamer timeout (#22576)
This commit is contained in:
parent 11fd2c773b
commit 861ff890d6
2 changed files with 29 additions and 5 deletions
@@ -14,7 +14,7 @@
 # limitations under the License.

 from queue import Queue
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional


 if TYPE_CHECKING:
@@ -133,6 +133,9 @@ class TextIteratorStreamer(TextStreamer):
             The tokenizer used to decode the tokens.
         skip_prompt (`bool`, *optional*, defaults to `False`):
             Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
+        timeout (`float`, *optional*):
+            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
+            in `.generate()`, when it is called in a separate thread.
         decode_kwargs (`dict`, *optional*):
             Additional keyword arguments to pass to the tokenizer's `decode` method.

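A minimal usage sketch of the behavior documented above (not part of this commit): `.generate()` runs in a worker thread while the main thread consumes the streamer, and the `timeout` turns a stalled or crashed generation into a `queue.Empty` exception instead of an indefinite block. The checkpoint name, prompt, timeout value, and generation settings below are placeholders.

```python
# Sketch only: illustrates the documented `timeout` parameter; values are placeholders.
from queue import Empty
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("An increasing sequence: one,", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5.0)

# Run generation in a background thread; the streamer bridges the two threads.
thread = Thread(target=model.generate, kwargs={**inputs, "max_new_tokens": 20, "streamer": streamer})
thread.start()

generated = ""
try:
    # Each iteration waits at most 5 seconds for the next chunk of decoded text.
    for new_text in streamer:
        generated += new_text
except Empty:
    # Reached only if generation stalls or raises in the worker thread.
    pass
finally:
    thread.join()

print(generated)
```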
@@ -159,22 +162,25 @@ class TextIteratorStreamer(TextStreamer):
     ```
     """

-    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
+    def __init__(
+        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
+    ):
         super().__init__(tokenizer, skip_prompt, **decode_kwargs)
         self.text_queue = Queue()
         self.stop_signal = None
+        self.timeout = timeout

     def on_finalized_text(self, text: str, stream_end: bool = False):
         """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
-        self.text_queue.put(text)
+        self.text_queue.put(text, timeout=self.timeout)
         if stream_end:
-            self.text_queue.put(self.stop_signal)
+            self.text_queue.put(self.stop_signal, timeout=self.timeout)

     def __iter__(self):
         return self

     def __next__(self):
-        value = self.text_queue.get()
+        value = self.text_queue.get(timeout=self.timeout)
         if value == self.stop_signal:
             raise StopIteration()
         else:
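For context (an aside, not part of the diff): `self.timeout` is forwarded directly to the standard-library `queue.Queue.put()` and `queue.Queue.get()`. Since the queue is created without a `maxsize`, `put()` never actually blocks, so in practice the timeout matters on the consumer side, where an expired `get()` raises `queue.Empty`. A standalone sketch of that stdlib behavior:

```python
# Standalone illustration of the queue semantics the change above relies on;
# nothing here depends on transformers.
from queue import Empty, Queue

q = Queue()  # unbounded, so put(..., timeout=...) returns immediately
q.put("chunk", timeout=0.001)
print(q.get(timeout=0.001))  # "chunk" is already available, returned at once

try:
    q.get(timeout=0.001)  # queue is now empty: raises queue.Empty after ~1 ms
except Empty:
    print("timed out waiting for the next chunk")
```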
@@ -14,6 +14,7 @@
 # limitations under the License.

 import unittest
+from queue import Empty
 from threading import Thread

 from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available
@@ -102,3 +103,20 @@ class StreamerTester(unittest.TestCase):
         streamer_text = cs.out[:-1]  # Remove the final "\n"
         streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt")
         self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1))
+
+    def test_iterator_streamer_timeout(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        streamer = TextIteratorStreamer(tokenizer, timeout=0.001)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        # The streamer will timeout after 0.001 seconds, so an exception will be raised
+        with self.assertRaises(Empty):
+            streamer_text = ""
+            for new_text in streamer:
+                streamer_text += new_text