From 11b2e45ccc569f8ef8a8e4dfcb4c80af7fee0335 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 15 Nov 2022 11:04:58 +0100 Subject: [PATCH] [WHISPER] Update modeling tests (#20162) * Update modeling tests * update tokenization test * typo * nit * fix expected attention outputs * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests from review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: ydshieh * remove problematics kwargs passed to the padding function Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: ydshieh --- .../models/whisper/feature_extraction_whisper.py | 1 - tests/models/whisper/test_modeling_tf_whisper.py | 8 +++++--- tests/models/whisper/test_modeling_whisper.py | 8 +++++--- tests/models/whisper/test_tokenization_whisper.py | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 2640a2925..1c1a8d369 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -307,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): max_length=max_length if max_length else self.n_samples, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of, - **kwargs, ) # make sure list is in array format input_features = padded_inputs.get("input_features").transpose(2, 0, 1) diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index ae99c9408..a82f396d8 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -650,13 +650,15 @@ def _test_large_logits_librispeech(in_queue, out_queue, timeout): input_speech = _load_datasamples(1) processor = WhisperProcessor.from_pretrained("openai/whisper-large") - processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf") + processed_inputs = processor( + audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="tf" + ) input_features = processed_inputs.input_features - labels = processed_inputs.labels + decoder_input_ids = processed_inputs.labels logits = model( input_features, - decoder_input_ids=labels, + decoder_input_ids=decoder_input_ids, output_hidden_states=False, output_attentions=False, use_cache=False, diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index c629b10bf..d85f5ca54 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -853,13 +853,15 @@ class WhisperModelIntegrationTests(unittest.TestCase): input_speech = self._load_datasamples(1) processor = WhisperProcessor.from_pretrained("openai/whisper-large") - processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="pt") + processed_inputs = processor( + audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="pt" + ) input_features = processed_inputs.input_features.to(torch_device) - labels = processed_inputs.labels.to(torch_device) + decoder_input_ids = processed_inputs.labels.to(torch_device) logits = model( input_features, - decoder_input_ids=labels, + decoder_input_ids=decoder_input_ids, output_hidden_states=False, output_attentions=False, use_cache=False, diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index d01c41c0a..272df8e33 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -96,7 +96,7 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): @slow def test_tokenizer_integration(self): # fmt: off - expected_encoding = {'input_ids': [[41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13], [13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13], [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # noqa: E501 + expected_encoding = {'input_ids': [[50257, 50362, 41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13, 50256], [50257, 50362, 13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13, 50256], [50257, 50362, 464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # noqa: E501 # fmt: on self.tokenizer_integration_test_util(