mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
if output is tuple like facebook/hf-seamless-m4t-medium, waveform is … (#29722)
* if output is tuple like facebook/hf-seamless-m4t-medium, waveform is the first element

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* add test and fix batch issue

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* add dict output support for seamless_m4t

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
This commit is contained in:
parent
8b52fa6b42
commit
79d62b2da2
4 changed files with 29 additions and 3 deletions
|
|
@ -3496,7 +3496,6 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel):
|
|||
self.device
|
||||
)
|
||||
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
|
||||
|
||||
# second generation
|
||||
unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech)
|
||||
output_unit_ids = unit_ids.detach().clone()
|
||||
|
|
|
|||
|
|
@ -128,9 +128,12 @@ class PipelineIterator(IterableDataset):
|
|||
# Try to infer the size of the batch
|
||||
if isinstance(processed, torch.Tensor):
|
||||
first_tensor = processed
|
||||
elif isinstance(processed, tuple):
|
||||
first_tensor = processed[0]
|
||||
else:
|
||||
key = list(processed.keys())[0]
|
||||
first_tensor = processed[key]
|
||||
|
||||
if isinstance(first_tensor, list):
|
||||
observed_batch_size = len(first_tensor)
|
||||
else:
|
||||
|
|
@ -140,7 +143,7 @@ class PipelineIterator(IterableDataset):
|
|||
# elements.
|
||||
self.loader_batch_size = observed_batch_size
|
||||
# Setting internal index to unwrap the batch
|
||||
self._loader_batch_data = processed
|
||||
self._loader_batch_data = processed[0] if isinstance(processed, tuple) else processed
|
||||
self._loader_batch_index = 0
|
||||
return self.loader_batch_item()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -200,7 +200,10 @@ class TextToAudioPipeline(Pipeline):
|
|||
|
||||
def postprocess(self, waveform):
|
||||
output_dict = {}
|
||||
|
||||
if isinstance(waveform, dict):
|
||||
waveform = waveform["waveform"]
|
||||
elif isinstance(waveform, tuple):
|
||||
waveform = waveform[0]
|
||||
output_dict["audio"] = waveform.cpu().float().numpy()
|
||||
output_dict["sampling_rate"] = self.sampling_rate
|
||||
|
||||
|
|
|
|||
|
|
@ -66,6 +66,27 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||
audio = [output["audio"] for output in outputs]
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
||||
@slow
@require_torch
def test_medium_seamless_m4t_pt(self):
    """Smoke-test the text-to-audio pipeline against facebook/hf-seamless-m4t-medium (PyTorch).

    Covers both forward-param variants of the model (plain, and with
    `return_intermediate_token_ids=True`, which makes the model return a
    tuple/dict output) across single-input, multi-input, and batched calls.
    """
    generator = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt")

    # Exercise both output shapes of the model: waveform-only and
    # intermediate-token-ids (tuple/dict) outputs.
    param_variants = [
        {"tgt_lang": "eng"},
        {"return_intermediate_token_ids": True, "tgt_lang": "eng"},
    ]
    for params in param_variants:
        # single input → a single {"audio", "sampling_rate"} dict
        single = generator("This is a test", forward_params=params)
        self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 16000}, single)

        # test two examples side-by-side
        pair = generator(["This is a test", "This is a second test"], forward_params=params)
        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], [out["audio"] for out in pair])

        # test batching
        batched = generator(
            ["This is a test", "This is a second test"], forward_params=params, batch_size=2
        )
        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], [out["audio"] for out in batched])
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_small_bark_pt(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue