mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-15 21:01:19 +00:00
[WHISPER] Update modeling tests (#20162)
* Update modeling tests * update tokenization test * typo * nit * fix expected attention outputs * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests from review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * remove problematics kwargs passed to the padding function Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
f60eec4003
commit
11b2e45ccc
4 changed files with 11 additions and 8 deletions
|
|
@ -307,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||
max_length=max_length if max_length else self.n_samples,
|
||||
truncation=truncation,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
**kwargs,
|
||||
)
|
||||
# make sure list is in array format
|
||||
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
|
||||
|
|
|
|||
|
|
@ -650,13 +650,15 @@ def _test_large_logits_librispeech(in_queue, out_queue, timeout):
|
|||
input_speech = _load_datasamples(1)
|
||||
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf")
|
||||
processed_inputs = processor(
|
||||
audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="tf"
|
||||
)
|
||||
input_features = processed_inputs.input_features
|
||||
labels = processed_inputs.labels
|
||||
decoder_input_ids = processed_inputs.labels
|
||||
|
||||
logits = model(
|
||||
input_features,
|
||||
decoder_input_ids=labels,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
|
|
|
|||
|
|
@ -853,13 +853,15 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
input_speech = self._load_datasamples(1)
|
||||
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="pt")
|
||||
processed_inputs = processor(
|
||||
audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="pt"
|
||||
)
|
||||
input_features = processed_inputs.input_features.to(torch_device)
|
||||
labels = processed_inputs.labels.to(torch_device)
|
||||
decoder_input_ids = processed_inputs.labels.to(torch_device)
|
||||
|
||||
logits = model(
|
||||
input_features,
|
||||
decoder_input_ids=labels,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
@slow
|
||||
def test_tokenizer_integration(self):
|
||||
# fmt: off
|
||||
expected_encoding = {'input_ids': [[41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13], [13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13], [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # noqa: E501
|
||||
expected_encoding = {'input_ids': [[50257, 50362, 41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13, 50256], [50257, 50362, 13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13, 50256], [50257, 50362, 464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # noqa: E501
|
||||
# fmt: on
|
||||
|
||||
self.tokenizer_integration_test_util(
|
||||
|
|
|
|||
Loading…
Reference in a new issue