mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[Whisper] Fix timestamp processor (#21187)
* add draft logit processor * add template functions * update timestamp processor parameters * draft script * simplify code * cleanup * fixup and clean * update pipeline * style * clean up previous idea * add tokenization utils * update tokenizer and asr output * fit whisper type * style and update test * clean test * style test * update tests * update error test * update code (not based on review yet) * update tokenization * update asr pipeline * update code * cleanup and update test * fmt * remove text verification * cleanup * cleanup * add model test * update tests * update code add docstring * update code and add docstring * fix pipeline tests * add draft logit processor add template functions update timestamp processor parameters draft script simplify code cleanup fixup and clean update pipeline style clean up previous idea add tokenization utils update tokenizer and asr output fit whisper type style and update test clean test style test update tests update error test update code (not based on review yet) update tokenization update asr pipeline update code cleanup and update test fmt remove text verification cleanup cleanup add model test update tests update code add docstring update code and add docstring fix pipeline tests * Small update. * Fixup. * Tmp. * More support. * Making `forced_decoder_ids` non mandatory for users to set.
* update and fix first bug * properly process sequence right after merge if last * tofo * allow list inputs + compute begin index better * start adding tests * add the 3 edge cases * style * format sequences * fixup * update * update * style * test passes, edge cases should be good * update last value * remove Trie * update tests and expected values * handle bigger chunk_length * clean tests a bit * refactor chunk iter and clean pipeline * update tests * style * refactor chunk iter and clean pipeline * update * resolve comments * Apply suggestions from code review Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * take stride right into account * update test expected values * Update code based on review Co-authored-by: sgugger <sylvain.gugger@gmail.com> * major refactor * add correct strides for tests * Update src/transformers/pipelines/automatic_speech_recognition.py * fix whisper timestamp test Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
parent
9b42c68f7c
commit
e9b4800dda
3 changed files with 87 additions and 55 deletions
|
|
@ -90,7 +90,6 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
|
|||
# approximation of the token to time ratio : ~0.2seconds
|
||||
time_precision = feature_extractor.chunk_length / max_source_positions
|
||||
time = 0
|
||||
actual_offset = 0
|
||||
for seq_idx, item in enumerate(sequences):
|
||||
sequence, stride = item
|
||||
if isinstance(sequence, list):
|
||||
|
|
@ -101,75 +100,81 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
|
|||
begin_idx = np.where(sequence == timestamp_begin)[0].item() if timestamp_begin in sequence else 0
|
||||
sequence = sequence[begin_idx:]
|
||||
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
|
||||
if seq_idx != 0:
|
||||
time -= stride_left + stride_right
|
||||
offset = int((time / feature_extractor.sampling_rate) / time_precision)
|
||||
timestamp_tokens = np.where(sequence >= timestamp_begin)[0][1::2]
|
||||
if len(timestamp_tokens) >= 1:
|
||||
# if a big chunk length is used, we need to check all of the previous items
|
||||
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
|
||||
# relevant timestamps are in the overlapping part
|
||||
relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
|
||||
if relevant_timestamp.shape[0] > 0:
|
||||
relevant_timestamp = (
|
||||
consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
|
||||
)
|
||||
# if a big stride is used, we need to check some of the previous items for the best overlap
|
||||
best_match = 0
|
||||
sliced_sequence = []
|
||||
for idx, previous_sequence in enumerate(reversed(items)):
|
||||
previous_tokens = previous_sequence[1:-1]
|
||||
if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
|
||||
break # the previous sequence is too far in the past
|
||||
if len(previous_tokens) > 0:
|
||||
# find the longest common sequence between the overlapping parts
|
||||
index_left, index_right, match_length = _fast_find_longest_common_sequence(
|
||||
sequence, previous_tokens
|
||||
sequence[1:relevant_timestamp], previous_tokens
|
||||
)
|
||||
# don't do anything if only 1 token was matched
|
||||
if match_length > 1 and match_length > best_match:
|
||||
best_match = match_length
|
||||
best_idx = idx
|
||||
end_of_curr_sequence_idx = (
|
||||
np.where(sequence[index_left:] >= timestamp_begin)[0][0] + 1 + index_left
|
||||
np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
|
||||
)
|
||||
sliced_sequence = sequence[index_left:end_of_curr_sequence_idx]
|
||||
end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
|
||||
# if all the tokens are matched, suffix
|
||||
if index_left == 0 and match_length == len(previous_tokens):
|
||||
sliced_sequence = np.insert(
|
||||
sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
|
||||
)
|
||||
sliced_sequence[-1] = previous_sequence[-1]
|
||||
# if part of the previous sequence is not taken
|
||||
elif index_left > 0:
|
||||
elif index_left >= 0:
|
||||
sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
|
||||
# let's insert the missing part of the previous sequence
|
||||
sliced_sequence = np.insert(sliced_sequence, 0, previous_sequence[: index_right + 1])
|
||||
previous_slice = (
|
||||
previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
|
||||
)
|
||||
sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
|
||||
sliced_sequence[-1] += offset
|
||||
|
||||
if len(sliced_sequence) > 0:
|
||||
items[len(items) - best_idx - 1] = sliced_sequence
|
||||
items = items[: len(items) - best_idx]
|
||||
sequence = sequence[end_of_curr_sequence_idx:]
|
||||
actual_offset = items[-1][-1] - timestamp_begin
|
||||
|
||||
# sequence might have changed
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
if sum(timestamp_tokens) > 0:
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = (
|
||||
np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
|
||||
)
|
||||
|
||||
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
|
||||
if len(consecutive) > 0:
|
||||
last_slice = 0
|
||||
# take the last timestamp of the previous chunk
|
||||
for current_slice in consecutive:
|
||||
actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
|
||||
sliced_tokens = sequence[last_slice:current_slice]
|
||||
# set correct timestamps
|
||||
sliced_tokens[0] += actual_offset
|
||||
sliced_tokens[-1] += actual_offset
|
||||
items.append(sliced_tokens) # correct final sequence
|
||||
last_slice = current_slice
|
||||
# check if we have a non consecutive timestamp at the end
|
||||
if np.where(timestamp_tokens)[0][-1] != current_slice:
|
||||
# offset = items[-1][-1] if len(items) > 0 else timestamp_begin
|
||||
sliced_tokens = sequence[current_slice : np.where(timestamp_tokens)[0][-1] + 1]
|
||||
sliced_tokens[0] += actual_offset
|
||||
sliced_tokens[-1] += actual_offset
|
||||
duration = sliced_tokens[-1] - sliced_tokens[0]
|
||||
sliced_tokens[0] = actual_offset
|
||||
sliced_tokens[-1] = actual_offset + duration
|
||||
items.append(sliced_tokens)
|
||||
else:
|
||||
timestamps = sequence[timestamp_tokens.nonzero()[0].flatten()]
|
||||
if len(timestamps) > 0 and timestamps[-1].item() != timestamp_begin:
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
last_idx = np.argwhere(sequence == timestamps[-1])[0][0]
|
||||
sliced_sequence = sequence[: last_idx + 1]
|
||||
duration = sliced_sequence[-1] - sliced_sequence[0]
|
||||
# We need to discard the previous timing information
|
||||
sliced_sequence[0] = items[-1][-1]
|
||||
sliced_sequence[-1] = items[-1][-1] + duration
|
||||
items.append(sliced_sequence)
|
||||
# The beginning time of the next chunk
|
||||
last_slice = current_slice
|
||||
|
||||
time += chunk_len
|
||||
result = []
|
||||
for i in range(len(items)):
|
||||
|
|
|
|||
|
|
@ -1086,26 +1086,53 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
generated_ids = model.generate(input_features, max_length=448, logits_processor=timestamp_processor).to("cpu")
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404])
|
||||
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
{
|
||||
'text': " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
|
||||
'offsets': [
|
||||
{'text': ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.', 'timestamp': (0.0, 5.62)},
|
||||
{'text': " Nor is Mr. Quilter's manner less interesting than his matter.", 'timestamp': (5.62, 10.36)},
|
||||
{'text': ' He tells us that at this festive season of the year,', 'timestamp': (10.36, 14.46)},
|
||||
{'text': ' with Christmas and roast beef looming before us,', 'timestamp': (14.46, 17.76)},
|
||||
{'text': ' similes drawn from eating and its results occur most readily to the mind.', 'timestamp': (17.76, 22.8)},
|
||||
{'text': " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,", 'timestamp': (22.8, 28.82)}
|
||||
]
|
||||
"text": (
|
||||
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is"
|
||||
" Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season"
|
||||
" of the year, with Christmas and roast beef looming before us, similarly drawn from eating and"
|
||||
" its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins'"
|
||||
" work is really Greek after all, and"
|
||||
),
|
||||
"offsets": [
|
||||
{
|
||||
"text": (
|
||||
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
|
||||
),
|
||||
"timestamp": (0.0, 6.5600000000000005),
|
||||
},
|
||||
{
|
||||
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
"timestamp": (6.5600000000000005, 11.24),
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef"
|
||||
" looming"
|
||||
),
|
||||
"timestamp": (11.24, 16.88),
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
" before us, similarly drawn from eating and its results occur most readily to the mind."
|
||||
),
|
||||
"timestamp": (16.88, 23.76),
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
" He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and"
|
||||
),
|
||||
"timestamp": (23.76, 29.44),
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
|
|
|||
|
|
@ -344,7 +344,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||
},
|
||||
)
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (3000, 0, 0)], [next_sequences_1, (3000, 750, 0)]],
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_1, (480_000, 120_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
|
|
@ -381,7 +381,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||
# fmt: on
|
||||
# {'text': ' of spectators, retrievality is not worth thinking about. His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (3000, 0, 0)], [next_sequences_2, (3000, 750, 0)]],
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_2, (480_000, 120_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
|
|
@ -417,7 +417,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||
# fmt: on
|
||||
# {'text': ' His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_3, (480_000, 120_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
|
|
@ -447,11 +447,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||
# last case is when the sequence is not in the first next predicted start and end of timestamp
|
||||
# fmt: off
|
||||
next_sequences_3 = [
|
||||
[50364, 2812, 9836, 14783, 390, 51492, 406, 3163, 1953, 466, 13, 50634, 50634, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50934]
|
||||
[50364, 2812, 9836, 14783, 390, 406, 3163, 1953, 466, 13, 50634, 50634, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50934]
|
||||
]
|
||||
# fmt: on
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_3, (480_000, 167_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
|
|
@ -459,7 +459,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||
# fmt: off
|
||||
self.assertEqual(
|
||||
merge,
|
||||
[51492, 406, 3163, 1953, 466, 13, 53112, 53112, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 53332],
|
||||
[51492, 406, 3163, 1953, 466, 13, 51612, 51612, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51912]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertEqual(
|
||||
|
|
@ -473,7 +473,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (24.96, 29.36),
|
||||
"timestamp": (24.96, 30.96),
|
||||
},
|
||||
],
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in a new issue