Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
[pipeline] fix padding for 1-d tensors (#31776)
* [pipeline] fix padding for 1-d tensors

* add test

* make style

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

---------

Co-authored-by: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com>
This commit is contained in:
parent 3fbaaaa64d
commit 7f5d644e69
2 changed files with 20 additions and 0 deletions
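For context, a minimal standalone sketch (not part of the commit) of the behaviour the `_pad` change adds: when every batch item holds a 1-d tensor under the same key, the tensors can be concatenated directly instead of padded. The key name `num_frames` below is a hypothetical placeholder, chosen only for illustration.

import torch

# Hypothetical batch items: each holds a 1-d tensor under the same key.
items = [
    {"num_frames": torch.tensor([3000])},
    {"num_frames": torch.tensor([1500])},
    {"num_frames": torch.tensor([2400])},
]

# 1-d tensors need no padding value or max length: concatenating along dim 0,
# as the new `dim == 1` branch in `_pad` does, keeps every value unchanged.
stacked = torch.cat([item["num_frames"] for item in items], dim=0)
print(stacked)  # tensor([3000, 1500, 2400])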
@@ -90,6 +90,9 @@ def _pad(items, key, padding_value, padding_side):
         # Others include `attention_mask` etc...
         shape = items[0][key].shape
         dim = len(shape)
+        if dim == 1:
+            # We have a list of 1-dim torch tensors, which can be stacked without padding
+            return torch.cat([item[key] for item in items], dim=0)
         if key in ["pixel_values", "image"]:
             # This is probable image so padding shouldn't be necessary
             # B, C, H, W
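The early return keeps 1-d tensors out of the padding logic further down in `_pad`, which assumes a sequence dimension and a padding value; for 1-d items, concatenating along dim 0 reproduces the batch without inventing padded positions.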
@@ -549,6 +549,23 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
         self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
 
+    @require_torch
+    @slow
+    def test_torch_whisper_batched(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="openai/whisper-tiny",
+            framework="pt",
+        )
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:2]")
+        EXPECTED_OUTPUT = [
+            {"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
+            {"text": " Nor is Mr. Quilters' manner less interesting than his matter."},
+        ]
+
+        output = speech_recognizer(ds["audio"], batch_size=2)
+        self.assertEqual(output, EXPECTED_OUTPUT)
+
     @slow
     def test_find_longest_common_subsequence(self):
         max_source_positions = 1500
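Usage note: the new test_torch_whisper_batched is gated behind @require_torch and @slow, so, assuming the standard transformers test setup, it runs only with slow tests enabled, e.g. RUN_SLOW=1 python -m pytest tests/pipelines/test_pipelines_automatic_speech_recognition.py -k test_torch_whisper_batched.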