From 7f79a97399bb52aad8460e1da2f36577d5dccfed Mon Sep 17 00:00:00 2001
From: Aviv Shamsian
Date: Fri, 12 Jul 2024 22:07:10 +0300
Subject: [PATCH] fix prompt strip to support tensors and np arrays (#27818)

* fix prompt strip to support tensors and np arrays

* framework agnostic

* change logic check before converting prompt into list

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adding _convert_to_list to tokenization_whisper_fast

* adding tests for prompt decoding

* adding comment

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adding comment

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* revert minor

* make style formatting

* style formatting after update

* Update src/transformers/models/whisper/tokenization_whisper_fast.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fixing _strip_prompt to handle _decode_with_timestamps

* fix copies

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
---
Review notes (ignored by `git am`; not part of the commit message):
 - _convert_to_list: the original condition `A or B and C` parsed as
   `A or (B and C)`, so the hasattr(token_ids, "numpy") guard only covered
   the tensorflow case. It now guards both frameworks.
 - test_tokenizer_decode_prompt: the prompt-stripping check decoded with the
   slow tokenizer twice; it now exercises the fast tokenizer as well.
 - NOTE(review): `.numpy()` raises for torch tensors that live on an
   accelerator or require grad; a follow-up should call `.cpu()` first.
 - All changes are line-count neutral, so hunk headers and the diffstat
   below are unchanged. The `index` hashes refer to the original commit.

 .../models/whisper/tokenization_whisper.py    | 23 ++++++++++--
 .../whisper/tokenization_whisper_fast.py      | 24 +++++++++++--
 .../whisper/test_tokenization_whisper.py      | 35 +++++++++++++++++++
 3 files changed, 76 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 303822de6..aea9dd289 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -851,9 +851,16 @@ class WhisperTokenizer(PreTrainedTokenizer):
         batch_encoding.convert_to_tensors(tensor_type=return_tensors)
         return batch_encoding["input_ids"]
 
-    @staticmethod
-    def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
-        has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
+    def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
+        if not isinstance(token_ids, list):
+            token_ids = self._convert_to_list(token_ids)
+
+        # handle case of empty token_ids for decoding with timestamps.
+        # at this point token_ids is a list, so it is safe to use if not check.
+        if not token_ids:
+            return token_ids
+
+        has_prompt = token_ids[0] == prompt_token_id
         if has_prompt:
             if decoder_start_token_id in token_ids:
                 return token_ids[token_ids.index(decoder_start_token_id) :]
@@ -862,6 +869,16 @@ class WhisperTokenizer(PreTrainedTokenizer):
 
         return token_ids
 
+    @staticmethod
+    def _convert_to_list(token_ids):
+        # convert torch / tensorflow tensors to ndarray if necessary
+        if hasattr(token_ids, "numpy") and ("torch" in str(type(token_ids)) or "tensorflow" in str(type(token_ids))):
+            token_ids = token_ids.numpy()
+        # now the token ids are either a numpy array, or a (possibly nested) list
+        if isinstance(token_ids, np.ndarray):
+            token_ids = token_ids.tolist()
+        return token_ids
+
 
 def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
     """
diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py
index a9e57cca7..d1dee3826 100644
--- a/src/transformers/models/whisper/tokenization_whisper_fast.py
+++ b/src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -582,10 +582,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         batch_encoding.convert_to_tensors(tensor_type=return_tensors)
         return batch_encoding["input_ids"]
 
-    @staticmethod
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
-    def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
-        has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
+    def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
+        if not isinstance(token_ids, list):
+            token_ids = self._convert_to_list(token_ids)
+
+        # handle case of empty token_ids for decoding with timestamps.
+        # at this point token_ids is a list, so it is safe to use if not check.
+        if not token_ids:
+            return token_ids
+
+        has_prompt = token_ids[0] == prompt_token_id
         if has_prompt:
             if decoder_start_token_id in token_ids:
                 return token_ids[token_ids.index(decoder_start_token_id) :]
@@ -593,3 +600,14 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
                 return []
 
         return token_ids
+
+    @staticmethod
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._convert_to_list
+    def _convert_to_list(token_ids):
+        # convert torch / tensorflow tensors to ndarray if necessary
+        if hasattr(token_ids, "numpy") and ("torch" in str(type(token_ids)) or "tensorflow" in str(type(token_ids))):
+            token_ids = token_ids.numpy()
+        # now the token ids are either a numpy array, or a (possibly nested) list
+        if isinstance(token_ids, np.ndarray):
+            token_ids = token_ids.tolist()
+        return token_ids
diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py
index 41598cf27..530e23351 100644
--- a/tests/models/whisper/test_tokenization_whisper.py
+++ b/tests/models/whisper/test_tokenization_whisper.py
@@ -14,6 +14,8 @@
 
 import unittest
 
+import numpy as np
+
 from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
 from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
 from transformers.testing_utils import slow
@@ -251,6 +253,39 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
 
         self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
 
+    def test_tokenizer_decode_prompt(self):
+        prompt_text = "What does the fox say?"
+        input_text = "Hatee hatee hatee ho"
+
+        tokenizer = self.get_tokenizer()
+        rust_tokenizer = self.get_rust_tokenizer()
+
+        # encode prompt and input text using tokenizer
+        prompt_ids = tokenizer.get_prompt_ids(prompt_text, return_tensors="np")
+        input_ids = tokenizer(input_text, return_tensors="np").input_ids[0]
+        input_ids = np.hstack([prompt_ids, input_ids])
+
+        # encode using fast tokenizer
+        rust_prompt_ids = rust_tokenizer.get_prompt_ids(prompt_text, return_tensors="np")
+        rust_input_ids = rust_tokenizer(input_text, return_tensors="np").input_ids[0]
+        rust_input_ids = np.hstack([rust_prompt_ids, rust_input_ids])
+
+        # check with prompt in output
+        pred_text = tokenizer.decode(input_ids, skip_special_tokens=False)
+        rust_pred_text = rust_tokenizer.decode(rust_input_ids, skip_special_tokens=False)
+
+        # check correctness for both tokenizers
+        expected_text = f"<|startofprev|> {prompt_text}<|startoftranscript|><|notimestamps|>{input_text}<|endoftext|>"
+        self.assertEqual(pred_text.strip(), expected_text)
+        self.assertEqual(rust_pred_text.strip(), expected_text)
+
+        # check stripping prompt from output
+        pred_text = tokenizer.decode(input_ids, skip_special_tokens=True)
+        rust_pred_text = rust_tokenizer.decode(rust_input_ids, skip_special_tokens=True)
+
+        self.assertEqual(pred_text.strip(), input_text)
+        self.assertEqual(rust_pred_text.strip(), input_text)
+
     def test_combine_tokens_into_words(self):
         tokenizer = self.get_tokenizer()
         rust_tokenizer = self.get_rust_tokenizer()