mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[Test] Fix W2V-Conformer integration test (#17303)
* [Test] Fix W2V-Conformer integration test
* correct w2v2
* up
This commit is contained in:
parent
28a0811652
commit
10704e1209
3 changed files with 7 additions and 9 deletions
|
|
@@ -1414,7 +1414,6 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
|||
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
|
||||
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
|
||||
>>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
|
||||
|
|
|
|||
|
|
@@ -1442,7 +1442,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
|||
|
||||
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2-base->wav2vec2-conformer-rel-pos-large,wav2vec2->wav2vec2_conformer
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
|
||||
def forward(
|
||||
self,
|
||||
input_values: Optional[torch.Tensor],
|
||||
|
|
@@ -1470,14 +1470,9 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
|||
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
|
||||
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
|
||||
... )
|
||||
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained(
|
||||
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
|
||||
... )
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
||||
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
||||
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
|
||||
|
|
|
|||
|
|
@@ -581,6 +581,10 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
module.weight_v.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None:
|
||||
module.pos_bias_u.data.fill_(3)
|
||||
if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None:
|
||||
module.pos_bias_v.data.fill_(3)
|
||||
if hasattr(module, "codevectors") and module.codevectors is not None:
|
||||
module.codevectors.data.fill_(3)
|
||||
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue