Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
Fix pad_to_max_length Whisper (#30787)
* fix pad_to_max_length Whisper
* add tests
* make style
parent b84cd67526
commit d355741eca

2 changed files with 84 additions and 3 deletions
src/transformers/models/whisper/generation_whisper.py

@@ -122,7 +122,9 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att
     return None
 
 
-def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None):
+def _pad_to_max_length(
+    current_segments, pad_token_id, device, padding="right", bos_token_tensor=None, cut_off_length=None
+):
     max_total_length = 0
     sequences = []
     if padding not in ["right", "left"]:
@@ -143,7 +145,7 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_toke
     elif bos_token_tensor is not None:
         sequences.append(bos_token_tensor)
     else:
-        sequences.append(torch.tensor([]))
+        sequences.append(torch.tensor([], device=device))
 
     for i in range(len(current_segments)):
         pad_length = max_total_length - len(sequences[i])
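The crux of the fix is the `device=device` on the empty placeholder above: `torch.tensor([])` always lands on the CPU, and stacking it together with CUDA-resident segment tensors raises a device-mismatch error. A minimal sketch of the failure mode, using a simplified, hypothetical stand-in for `_pad_to_max_length` (not the library code):

```python
import torch
import torch.nn.functional as F

def pad_and_stack(sequences, pad_token_id, device):
    # The fixed behaviour: the empty placeholder is created on the same device
    # as the real segments. Before this commit, the equivalent of
    # `torch.tensor([])` always lived on the CPU.
    sequences = [s if s.numel() else torch.tensor([], device=device) for s in sequences]
    max_len = max(len(s) for s in sequences)
    padded = [F.pad(s, (0, max_len - len(s)), value=pad_token_id) for s in sequences]
    return torch.stack(padded)

if torch.cuda.is_available():
    segments = [torch.arange(4.0, device="cuda"), torch.tensor([])]
    # With a CPU empty placeholder, torch.stack would raise:
    #   RuntimeError: Expected all tensors to be on the same device ...
    out = pad_and_stack(segments, pad_token_id=0.0, device="cuda")
    print(out.device)  # cuda:0
```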
@@ -733,7 +735,9 @@ class WhisperGenerationMixin:
             if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
             else current_segments
         )
-        sequences = _pad_to_max_length(final_segments, generation_config.pad_token_id, padding="right")
+        sequences = _pad_to_max_length(
+            final_segments, generation_config.pad_token_id, device=self.device, padding="right"
+        )
 
         # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
         if return_segments:
@@ -1506,6 +1510,7 @@ class WhisperGenerationMixin:
                 prev_tokens = _pad_to_max_length(
                     active_segments,
                     generation_config.pad_token_id,
+                    device=device,
                     padding="left",
                     bos_token_tensor=prev_ids,
                     cut_off_length=cut_off_length,
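Both call sites now thread the device through explicitly: the final right-padding pass uses `self.device`, while the prompt-conditioning path above uses the local `device` and pads on the left so the most recent tokens stay flush against the new segment. An illustrative sketch of that left-padding path, with a hypothetical `left_pad` helper and deliberately simplified `bos`/`cut_off_length` semantics (this is not the library implementation):

```python
import torch
import torch.nn.functional as F

def left_pad(seqs, pad_token_id, device, bos=None, cut_off_length=None):
    # Prefix each sequence with the BOS/prompt tensor, keep only the most
    # recent cut_off_length tokens, then left-pad so everything aligns at
    # the right edge (newest tokens last).
    if bos is not None:
        seqs = [torch.cat([bos.to(device), s.to(device)]) for s in seqs]
    if cut_off_length is not None:
        seqs = [s[-cut_off_length:] for s in seqs]
    max_len = max(len(s) for s in seqs)
    return torch.stack([F.pad(s, (max_len - len(s), 0), value=pad_token_id) for s in seqs])

prev = [torch.tensor([5, 6, 7]), torch.tensor([9])]
print(left_pad(prev, pad_token_id=0, device="cpu", bos=torch.tensor([1])))
# tensor([[1, 5, 6, 7],
#         [0, 0, 1, 9]])
```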
tests/models/whisper/test_modeling_whisper.py

@@ -35,6 +35,7 @@ from transformers.testing_utils import (
     require_torch,
     require_torch_fp16,
     require_torch_gpu,
+    require_torch_multi_gpu,
     require_torchaudio,
     slow,
     torch_device,
@@ -2866,6 +2867,81 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         for i in range(num_samples):
             assert decoded_all[i] == EXPECTED_TEXT[i]
 
+    @require_torch_gpu
+    @slow
+    def test_whisper_empty_longform(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model = model.to(torch_device)
+
+        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
+        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+        num_samples = 8
+
+        audio = ds[:num_samples]["audio"]
+        audios = [x["array"] for x in audio]
+        audios[0][:] = np.zeros(audios[0].shape)
+
+        inputs = processor(
+            audios,
+            return_tensors="pt",
+            truncation=False,
+            padding="longest",
+            return_attention_mask=True,
+            sampling_rate=16_000,
+        )
+        inputs = inputs.to(device=torch_device)
+
+        gen_kwargs = {
+            "no_speech_threshold": 0.2,
+            "temperature": (0.0,),
+            "logprob_threshold": 0.0,  # Ignore logprob, use only no-speech prob
+            "num_beams": 5,
+            "language": "fr",
+            "task": "transcribe",
+        }
+
+        torch.manual_seed(0)
+        model.generate(**inputs, **gen_kwargs)
+
+    @require_torch_multi_gpu
+    @slow
+    def test_whisper_empty_longform_multi_gpu(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", device_map="auto")
+
+        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
+        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+        num_samples = 8
+
+        audio = ds[:num_samples]["audio"]
+        audios = [x["array"] for x in audio]
+        audios[0][:] = np.zeros(audios[0].shape)
+
+        inputs = processor(
+            audios,
+            return_tensors="pt",
+            truncation=False,
+            padding="longest",
+            return_attention_mask=True,
+            sampling_rate=16_000,
+        )
+        inputs = inputs.to(device=model.device)
+
+        gen_kwargs = {
+            "no_speech_threshold": 0.2,
+            "temperature": (0.0,),
+            "logprob_threshold": 0.0,  # Ignore logprob, use only no-speech prob
+            "num_beams": 5,
+            "language": "fr",
+            "task": "transcribe",
+        }
+
+        torch.manual_seed(0)
+        model.generate(**inputs, **gen_kwargs)
+
 
 def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
     if head_mask is None:
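The tests reproduce the crash surface directly: long-form transcription where one input is pure silence, so segments come back empty and `_pad_to_max_length` must build its placeholder on the accelerator. A standalone sketch of the same scenario using only public `transformers` APIs (the 40-second length and the threshold values are illustrative choices, not requirements):

```python
import numpy as np
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)

# 40 s of silence: longer than 30 s, so Whisper takes the long-form path.
silence = np.zeros(16_000 * 40, dtype=np.float32)
inputs = processor(
    silence,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
    sampling_rate=16_000,
).to(device)

# Before this commit, running this on a GPU could fail with a CPU/CUDA
# device mismatch inside _pad_to_max_length.
out = model.generate(
    **inputs,
    no_speech_threshold=0.2,
    temperature=(0.0,),
    logprob_threshold=0.0,
)
print(processor.batch_decode(out, skip_special_tokens=True))
```

To run the new tests themselves, the suite's usual gates apply: both are decorated `@slow` (enabled via the `RUN_SLOW` environment variable), and they additionally require one GPU and multiple GPUs respectively.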