From e9b4800dda26423cb9ddebac41c5a3753afa2012 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Thu, 19 Jan 2023 16:25:56 +0100
Subject: [PATCH] [Whisper] Fix timestamp processor (#21187)

* add draft logit processor
* add template functions
* update timestamp processor parameters
* draft script
* simplify code
* cleanup
* fixup and clean
* update pipeline
* style
* clean up previous idea
* add tokenization utils
* update tokenizer and asr output
* fit whisper type
* style and update test
* clean test
* style test
* update tests
* update error test
* update code (not based on review yet)
* update tokenization
* update asr pipeline
* update code
* cleanup and update test
* fmt
* remove text verification
* cleanup
* cleanup
* add model test
* update tests
* update code, add docstring
* update code and add docstring
* fix pipeline tests
* Small update.
* Fixup.
* Tmp.
* More support.
* Making `forced_decoder_ids` non-mandatory for users to set.
* update and fix first bug
* properly process sequence right after merge if last
* todo
* allow list inputs + compute begin index better
* start adding tests
* add the 3 edge cases
* style
* format sequences
* fixup
* update
* update
* style
* test passes, edge cases should be good
* update last value
* remove Trie
* update tests and expected values
* handle bigger chunk_length
* clean tests a bit
* refactor chunk iter and clean pipeline
* update tests
* style
* refactor chunk iter and clean pipeline
* update
* resolve comments
* Apply suggestions from code review

Co-authored-by: Nicolas Patry

* take stride right into account
* update test expected values
* Update code based on review

Co-authored-by: sgugger

* major refactor
* add correct strides for tests
* Update src/transformers/pipelines/automatic_speech_recognition.py
* fix whisper timestamp test

Co-authored-by: Nicolas Patry
Co-authored-by: sgugger
---
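Notes: the rewritten merge logic segments each chunk by looking for pairs of
consecutive timestamp tokens (Whisper closes a segment by emitting the same
timestamp token twice) instead of taking every other timestamp token. A minimal
sketch of that detection idiom, with illustrative token IDs (50364 decodes to
<|0.00|> for the pretrained checkpoints, and each increment adds
time_precision = chunk_length / max_source_positions = 30 / 1500 = 0.02 s):

    import numpy as np

    timestamp_begin = 50364  # assumed: id of the first timestamp token, <|0.00|>
    # toy sequence: <|0.00|> ... <|5.40|><|5.40|> ... <|11.40|>
    sequence = np.array([50364, 2812, 9836, 50634, 50634, 14783, 390, 50934])

    timestamp_tokens = sequence >= timestamp_begin
    # positions where two timestamp tokens sit next to each other
    consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
    last_timestamp = np.where(timestamp_tokens)[0][-1]
    if last_timestamp not in consecutive:
        consecutive = np.append(consecutive, last_timestamp)
    print(consecutive)  # [4 7]: segment boundaries at indices 4 and 7

The slices sequence[last_slice:current_slice] taken at those boundaries are the
per-segment token runs that the patched code below re-offsets.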
 .../pipelines/automatic_speech_recognition.py | 77 ++++++++++---------
 tests/models/whisper/test_modeling_whisper.py | 51 +++++++++---
 ..._pipelines_automatic_speech_recognition.py | 14 ++--
 3 files changed, 87 insertions(+), 55 deletions(-)

diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index cd759db27..a41ba02f7 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -90,7 +90,6 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
     # approximation of the token to time ratio : ~0.2seconds
     time_precision = feature_extractor.chunk_length / max_source_positions
     time = 0
-    actual_offset = 0
     for seq_idx, item in enumerate(sequences):
         sequence, stride = item
         if isinstance(sequence, list):
@@ -101,75 +100,81 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
         begin_idx = np.where(sequence == timestamp_begin)[0].item() if timestamp_begin in sequence else 0
         sequence = sequence[begin_idx:]
+        timestamp_tokens = sequence >= timestamp_begin
+        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
+        last_timestamp = np.where(timestamp_tokens)[0][-1]
+        consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
         if seq_idx != 0:
             time -= stride_left + stride_right
             offset = int((time / feature_extractor.sampling_rate) / time_precision)
-            timestamp_tokens = np.where(sequence >= timestamp_begin)[0][1::2]
-            if len(timestamp_tokens) >= 1:
-                # if a big chunk lenght is used, we need to check all of the previous items
+            overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
+            # relevant timestamps are in the overlapping part
+            relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
+            if relevant_timestamp.shape[0] > 0:
+                relevant_timestamp = (
+                    consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
+                )
+                # if a big stride is used, we need to check some of the previous items for the best overlap
                 best_match = 0
                 sliced_sequence = []
                 for idx, previous_sequence in enumerate(reversed(items)):
                     previous_tokens = previous_sequence[1:-1]
+                    if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
+                        break  # the previous sequence is too far in the past
                     if len(previous_tokens) > 0:
+                        # find the longest common sequence between the overlapping parts
                         index_left, index_right, match_length = _fast_find_longest_common_sequence(
-                            sequence, previous_tokens
+                            sequence[1:relevant_timestamp], previous_tokens
                         )
                         # don't do anything if only 1 token was matched
                         if match_length > 1 and match_length > best_match:
                             best_match = match_length
                             best_idx = idx
                             end_of_curr_sequence_idx = (
-                                np.where(sequence[index_left:] >= timestamp_begin)[0][0] + 1 + index_left
+                                np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
                             )
-                            sliced_sequence = sequence[index_left:end_of_curr_sequence_idx]
+                            end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
                             # if all the tokens are matched, suffix
                             if index_left == 0 and match_length == len(previous_tokens):
+                                sliced_sequence = np.insert(
+                                    sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
+                                )
                                 sliced_sequence[-1] = previous_sequence[-1]
                             # if part of the previous sequence is not taken
-                            elif index_left > 0:
+                            elif index_left >= 0:
+                                sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
                                 # let's insert the missing part of the previous sequence
-                                sliced_sequence = np.insert(sliced_sequence, 0, previous_sequence[: index_right + 1])
+                                previous_slice = (
+                                    previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
+                                )
+                                sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
                                 sliced_sequence[-1] += offset
+                if len(sliced_sequence) > 0:
                     items[len(items) - best_idx - 1] = sliced_sequence
                     items = items[: len(items) - best_idx]
                     sequence = sequence[end_of_curr_sequence_idx:]
-                    actual_offset = items[-1][-1] - timestamp_begin
+                    # sequence might have changed
                     timestamp_tokens = sequence >= timestamp_begin
                     consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
+                    if sum(timestamp_tokens) > 0:
+                        last_timestamp = np.where(timestamp_tokens)[0][-1]
+                        consecutive = (
+                            np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
+                        )
-        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+        if len(consecutive) > 0:
             last_slice = 0
-            # take the last timestamp of the previous chunk
             for current_slice in consecutive:
+                actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
                 sliced_tokens = sequence[last_slice:current_slice]
-                # set correct timestamps
-                sliced_tokens[0] += actual_offset
-                sliced_tokens[-1] += actual_offset
-                items.append(sliced_tokens)  # correct final sequence
-                last_slice = current_slice
-            # check if we have a non consecutive timestamp at the end
-            if np.where(timestamp_tokens)[0][-1] != current_slice:
-                # offset = items[-1][-1] if len(items) > 0 else timestamp_begin
-                sliced_tokens = sequence[current_slice : np.where(timestamp_tokens)[0][-1] + 1]
-                sliced_tokens[0] += actual_offset
-                sliced_tokens[-1] += actual_offset
+                duration = sliced_tokens[-1] - sliced_tokens[0]
+                sliced_tokens[0] = actual_offset
+                sliced_tokens[-1] = actual_offset + duration
                 items.append(sliced_tokens)
-        else:
-            timestamps = sequence[timestamp_tokens.nonzero()[0].flatten()]
-            if len(timestamps) > 0 and timestamps[-1].item() != timestamp_begin:
-                # no consecutive timestamps but it has a timestamp; use the last one.
-                # single timestamp at the end means no speech after the last timestamp.
-                last_idx = np.argwhere(sequence == timestamps[-1])[0][0]
-                sliced_sequence = sequence[: last_idx + 1]
-                duration = sliced_sequence[-1] - sliced_sequence[0]
-                # We need to discard the previous timing information
-                sliced_sequence[0] = items[-1][-1]
-                sliced_sequence[-1] = items[-1][-1] + duration
-                items.append(sliced_sequence)
-        # The beginning time of the next chunk
+                last_slice = current_slice
         time += chunk_len
     result = []
     for i in range(len(items)):
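Where two 30 s windows overlap, the merge looks for the longest run of tokens
shared by the end of the previous chunk and the start of the current one (the
_fast_find_longest_common_sequence call above). A self-contained sketch of the
underlying idea, assuming a plain suffix/prefix overlap; this is not the
helper's actual implementation or return contract:

    import numpy as np

    def longest_overlap(previous, current):
        # Longest k such that the last k tokens of `previous` equal the
        # first k tokens of `current`; the shared run is kept only once.
        best = 0
        for k in range(1, min(len(previous), len(current)) + 1):
            if np.array_equal(previous[-k:], current[:k]):
                best = k
        return best

    previous = np.array([2812, 9836, 14783, 390])  # tail of chunk n (ids illustrative)
    current = np.array([14783, 390, 6263, 538])    # head of chunk n + 1
    print(longest_overlap(previous, current))      # 2 -> drop the duplicated tokens

The real helper additionally reports where the match starts and ends in each
sequence (index_left, index_right, match_length), which is what lets the code
above splice items and sequence around the overlap.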
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index ee7481050..cf2cb1d96 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -1086,26 +1086,53 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         generated_ids = model.generate(input_features, max_length=448, logits_processor=timestamp_processor).to("cpu")

         # fmt: off
-        EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404])
+        EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])
         # fmt: on

         self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))

-        # fmt: off
         EXPECTED_TRANSCRIPT = [
             {
-                'text': " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
-                'offsets': [
-                    {'text': ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.', 'timestamp': (0.0, 5.62)},
-                    {'text': " Nor is Mr. Quilter's manner less interesting than his matter.", 'timestamp': (5.62, 10.36)},
-                    {'text': ' He tells us that at this festive season of the year,', 'timestamp': (10.36, 14.46)},
-                    {'text': ' with Christmas and roast beef looming before us,', 'timestamp': (14.46, 17.76)},
-                    {'text': ' similes drawn from eating and its results occur most readily to the mind.', 'timestamp': (17.76, 22.8)},
-                    {'text': " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,", 'timestamp': (22.8, 28.82)}
-                ]
+                "text": (
+                    " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is"
+                    " Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season"
+                    " of the year, with Christmas and roast beef looming before us, similarly drawn from eating and"
+                    " its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins'"
+                    " work is really Greek after all, and"
+                ),
+                "offsets": [
+                    {
+                        "text": (
+                            " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
+                        ),
+                        "timestamp": (0.0, 6.5600000000000005),
+                    },
+                    {
+                        "text": " Nor is Mr. Quilter's manner less interesting than his matter.",
+                        "timestamp": (6.5600000000000005, 11.24),
+                    },
+                    {
+                        "text": (
+                            " He tells us that at this festive season of the year, with Christmas and roast beef"
+                            " looming"
+                        ),
+                        "timestamp": (11.24, 16.88),
+                    },
+                    {
+                        "text": (
+                            " before us, similarly drawn from eating and its results occur most readily to the mind."
+                        ),
+                        "timestamp": (16.88, 23.76),
+                    },
+                    {
+                        "text": (
+                            " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and"
+                        ),
+                        "timestamp": (23.76, 29.44),
+                    },
+                ],
             }
         ]
-        # fmt: on

         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
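The expected offsets in this test follow directly from Whisper's timestamp
encoding: a timestamp token's time is (token_id - timestamp_begin) *
time_precision, with time_precision = 30 / 1500 = 0.02 s and 50364 decoding to
<|0.00|>. A quick sanity check against the tensor above (constants assumed for
the openai/whisper checkpoints):

    timestamp_begin = 50364
    time_precision = 30 / 1500  # chunk_length / max_source_positions = 0.02 s

    def to_seconds(token_id):
        return (token_id - timestamp_begin) * time_precision

    print(to_seconds(50692))  # 6.5600000000000005, end of the first segment
    print(to_seconds(50926))  # 11.24, end of the second segment

The odd-looking 6.5600000000000005 in EXPECTED_TRANSCRIPT is plain float
arithmetic (328 * 0.02), not an error in the merged timestamps.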
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index 50d6759d7..dc304272f 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -344,7 +344,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
             },
         )
         merge = _find_timestamp_sequence(
-            [[previous_sequence, (3000, 0, 0)], [next_sequences_1, (3000, 750, 0)]],
+            [[previous_sequence, (480_000, 0, 0)], [next_sequences_1, (480_000, 120_000, 0)]],
             processor.tokenizer,
             processor.feature_extractor,
             max_source_positions,
@@ -381,7 +381,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         # fmt: on
         # {'text': ' of spectators, retrievality is not worth thinking about. His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
         merge = _find_timestamp_sequence(
-            [[previous_sequence, (3000, 0, 0)], [next_sequences_2, (3000, 750, 0)]],
+            [[previous_sequence, (480_000, 0, 0)], [next_sequences_2, (480_000, 120_000, 0)]],
             processor.tokenizer,
             processor.feature_extractor,
             max_source_positions,
@@ -417,7 +417,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         # fmt: on
         # {'text': ' His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
         merge = _find_timestamp_sequence(
-            [[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
+            [[previous_sequence, (480_000, 0, 0)], [next_sequences_3, (480_000, 120_000, 0)]],
             processor.tokenizer,
             processor.feature_extractor,
             max_source_positions,
@@ -447,11 +447,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         # last case is when the sequence is not in the first next predicted start and end of timestamp
         # fmt: off
         next_sequences_3 = [
-            [50364, 2812, 9836, 14783, 390, 51492, 406, 3163, 1953, 466, 13, 50634, 50634, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50934]
+            [50364, 2812, 9836, 14783, 390, 406, 3163, 1953, 466, 13, 50634, 50634, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50934]
         ]
         # fmt: on
         merge = _find_timestamp_sequence(
-            [[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
+            [[previous_sequence, (480_000, 0, 0)], [next_sequences_3, (480_000, 167_000, 0)]],
             processor.tokenizer,
             processor.feature_extractor,
             max_source_positions,
@@ -459,7 +459,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         # fmt: off
         self.assertEqual(
             merge,
-            [51492, 406, 3163, 1953, 466, 13, 53112, 53112, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 53332],
+            [51492, 406, 3163, 1953, 466, 13, 51612, 51612, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51912]
         )
         # fmt: on
         self.assertEqual(
@@ -473,7 +473,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
                     {"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
                     {
                         "text": " His instant panic was followed by a small, sharp blow high on his chest.",
-                        "timestamp": (24.96, 29.36),
+                        "timestamp": (24.96, 30.96),
                     },
                 ],
            },
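One unit change worth calling out in these tests: the stride tuples passed to
_find_timestamp_sequence are now expressed in raw audio samples rather than
feature frames (the old 3000 corresponds to 30 s of log-mel frames at 100
frames per second). At the feature extractor's 16 kHz sampling rate:

    sampling_rate = 16_000  # Whisper feature extractor default
    print(480_000 / sampling_rate)  # 30.0 s   -> chunk_len covers a full window
    print(120_000 / sampling_rate)  # 7.5 s    -> stride_left in the first cases
    print(167_000 / sampling_rate)  # 10.4375 s -> larger overlap in the last case

which is consistent with the pipeline's offset computation,
offset = int((time / feature_extractor.sampling_rate) / time_precision).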