diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py index 719c067d2..4006256f8 100644 --- a/src/transformers/generation/streamers.py +++ b/src/transformers/generation/streamers.py @@ -14,7 +14,7 @@ # limitations under the License. from queue import Queue -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: @@ -133,6 +133,9 @@ class TextIteratorStreamer(TextStreamer): The tokenized used to decode the tokens. skip_prompt (`bool`, *optional*, defaults to `False`): Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots. + timeout (`float`, *optional*): + The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions + in `.generate()`, when it is called in a separate thread. decode_kwargs (`dict`, *optional*): Additional keyword arguments to pass to the tokenizer's `decode` method. @@ -159,22 +162,25 @@ class TextIteratorStreamer(TextStreamer): ``` """ - def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): + def __init__( + self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs + ): super().__init__(tokenizer, skip_prompt, **decode_kwargs) self.text_queue = Queue() self.stop_signal = None + self.timeout = timeout def on_finalized_text(self, text: str, stream_end: bool = False): """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue.""" - self.text_queue.put(text) + self.text_queue.put(text, timeout=self.timeout) if stream_end: - self.text_queue.put(self.stop_signal) + self.text_queue.put(self.stop_signal, timeout=self.timeout) def __iter__(self): return self def __next__(self): - value = self.text_queue.get() + value = self.text_queue.get(timeout=self.timeout) if value == self.stop_signal: raise StopIteration() else: diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py index 7214e56cd..361f39e03 100644 --- a/tests/generation/test_streamers.py +++ b/tests/generation/test_streamers.py @@ -14,6 +14,7 @@ # limitations under the License. import unittest +from queue import Empty from threading import Thread from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available @@ -102,3 +103,20 @@ class StreamerTester(unittest.TestCase): streamer_text = cs.out[:-1] # Remove the final "\n" streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt") self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1)) + + def test_iterator_streamer_timeout(self): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + model.config.eos_token_id = -1 + + input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device) + streamer = TextIteratorStreamer(tokenizer, timeout=0.001) + generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer} + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + # The streamer will timeout after 0.001 seconds, so an exception will be raised + with self.assertRaises(Empty): + streamer_text = "" + for new_text in streamer: + streamer_text += new_text