mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Allow dict input for audio classification pipeline (#23445)
* Allow dict input for audio classification pipeline * make style * Empty commit to trigger CI * Empty commit to trigger CI * check for torchaudio * add pip instructions Co-authored-by: Sylvain <sylvain.gugger@gmail.com> * Update src/transformers/pipelines/audio_classification.py Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * asr -> audio class * asr -> audio class --------- Co-authored-by: Sylvain <sylvain.gugger@gmail.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
a6f37f8879
commit
8767958fc1
2 changed files with 50 additions and 8 deletions
|
|
@ -17,7 +17,7 @@ from typing import Union
|
|||
import numpy as np
|
||||
import requests
|
||||
|
||||
from ..utils import add_end_docstrings, is_torch_available, logging
|
||||
from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
|
||||
|
|
@ -110,12 +110,18 @@ class AudioClassificationPipeline(Pipeline):
|
|||
information.
|
||||
|
||||
Args:
|
||||
inputs (`np.ndarray` or `bytes` or `str`):
|
||||
The input is either a raw waveform (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
|
||||
at the correct sampling rate (no further check will be done) or a `str` that is the filename of the
|
||||
audio file, the file will be read at the correct sampling rate to get the waveform using *ffmpeg*. This
|
||||
requires *ffmpeg* to be installed on the system. If *inputs* is `bytes` it is supposed to be the
|
||||
content of an audio file and is interpreted by *ffmpeg* in the same way.
|
||||
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
|
||||
The input is either:
|
||||
- `str` that is the filename of the audio file, the file will be read at the correct sampling rate
|
||||
to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
|
||||
- `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
|
||||
same way.
|
||||
- (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
|
||||
Raw audio at the correct sampling rate (no further check will be done)
|
||||
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
|
||||
pipeline do the resampling. The dict must either be in the format `{"sampling_rate": int,
|
||||
"raw": np.array}`, or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or
|
||||
`"array"` is used to denote the raw audio waveform.
|
||||
top_k (`int`, *optional*, defaults to None):
|
||||
The number of top labels that will be returned by the pipeline. If the provided number is `None` or
|
||||
higher than the number of labels available in the model configuration, it will default to the number of
|
||||
|
|
@ -151,10 +157,42 @@ class AudioClassificationPipeline(Pipeline):
|
|||
if isinstance(inputs, bytes):
|
||||
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
|
||||
|
||||
if isinstance(inputs, dict):
|
||||
# Accepting `"array"` which is the key defined in `datasets` for
|
||||
# better integration
|
||||
if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
|
||||
raise ValueError(
|
||||
"When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
|
||||
'"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
|
||||
"containing the sampling_rate associated with that array"
|
||||
)
|
||||
|
||||
_inputs = inputs.pop("raw", None)
|
||||
if _inputs is None:
|
||||
# Remove path which will not be used from `datasets`.
|
||||
inputs.pop("path", None)
|
||||
_inputs = inputs.pop("array", None)
|
||||
in_sampling_rate = inputs.pop("sampling_rate")
|
||||
inputs = _inputs
|
||||
if in_sampling_rate != self.feature_extractor.sampling_rate:
|
||||
import torch
|
||||
|
||||
if is_torchaudio_available():
|
||||
from torchaudio import functional as F
|
||||
else:
|
||||
raise ImportError(
|
||||
"torchaudio is required to resample audio samples in AudioClassificationPipeline. "
|
||||
"The torchaudio package can be installed through: `pip install torchaudio`."
|
||||
)
|
||||
|
||||
inputs = F.resample(
|
||||
torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
|
||||
).numpy()
|
||||
|
||||
if not isinstance(inputs, np.ndarray):
|
||||
raise ValueError("We expect a numpy ndarray as input")
|
||||
if len(inputs.shape) != 1:
|
||||
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
||||
raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
|
||||
|
||||
processed = self.feature_extractor(
|
||||
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
|
|
|
|||
|
|
@ -103,6 +103,10 @@ class AudioClassificationPipelineTests(unittest.TestCase):
|
|||
]
|
||||
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
|
||||
|
||||
audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate}
|
||||
output = audio_classifier(audio_dict, top_k=4)
|
||||
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_large_model_pt(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue