From 02451cda74acff9a0873f36acbcbfcdcf9afe24b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 9 Feb 2021 11:49:02 +0300 Subject: [PATCH] Deprecate Wav2Vec2ForMaskedLM and add Wav2Vec2ForCTC (#10089) * add wav2vec2CTC and deprecate for maskedlm * remove from docs --- docs/source/model_doc/wav2vec2.rst | 4 +- src/transformers/__init__.py | 2 + src/transformers/models/wav2vec2/__init__.py | 2 + ..._original_pytorch_checkpoint_to_pytorch.py | 4 +- .../models/wav2vec2/modeling_wav2vec2.py | 81 ++++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 5 ++ tests/test_modeling_wav2vec2.py | 10 +-- utils/check_repo.py | 2 + 8 files changed, 100 insertions(+), 10 deletions(-) diff --git a/docs/source/model_doc/wav2vec2.rst b/docs/source/model_doc/wav2vec2.rst index 6df73a492..3dd6e293b 100644 --- a/docs/source/model_doc/wav2vec2.rst +++ b/docs/source/model_doc/wav2vec2.rst @@ -58,8 +58,8 @@ Wav2Vec2Model :members: forward -Wav2Vec2ForMaskedLM +Wav2Vec2ForCTC ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.Wav2Vec2ForMaskedLM +.. autoclass:: transformers.Wav2Vec2ForCTC :members: forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 83f0ee77b..e679bc57d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -367,6 +367,7 @@ if is_torch_available(): _import_structure["models.wav2vec2"].extend( [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", + "Wav2Vec2ForCTC", "Wav2Vec2ForMaskedLM", "Wav2Vec2Model", "Wav2Vec2PreTrainedModel", @@ -1813,6 +1814,7 @@ if TYPE_CHECKING: ) from .models.wav2vec2 import ( WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, + Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2PreTrainedModel, diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py index e9b13de2c..22066fadf 100644 --- a/src/transformers/models/wav2vec2/__init__.py +++ b/src/transformers/models/wav2vec2/__init__.py @@ -29,6 +29,7 @@ if is_torch_available(): _import_structure["modeling_wav2vec2"] = [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Wav2Vec2ForMaskedLM", + "Wav2Vec2ForCTC", "Wav2Vec2Model", "Wav2Vec2PreTrainedModel", ] @@ -41,6 +42,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_wav2vec2 import ( WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, + Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2PreTrainedModel, diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py index 110db7fb2..cbe74b8c7 100644 --- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py @@ -20,7 +20,7 @@ import argparse import fairseq import torch -from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, logging +from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging logging.set_verbosity_info() @@ -141,7 +141,7 @@ def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_ """ Copy/paste/tweak model's weights to transformers design. """ - hf_wav2vec = Wav2Vec2ForMaskedLM(Wav2Vec2Config()) + hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config()) model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( [checkpoint_path], arg_overrides={"data": dict_path} diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ba151a847..312b1bff3 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -15,6 +15,7 @@ """ PyTorch Wav2Vec2 model. """ +import warnings from typing import Optional, Tuple import torch @@ -24,7 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings -from ...modeling_outputs import BaseModelOutput, MaskedLMOutput +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_wav2vec2 import Wav2Vec2Config @@ -665,6 +666,10 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): def __init__(self, config): super().__init__(config) + warnings.warn( + "The class `Wav2Vec2ForMaskedLM` is deprecated. Please use `Wav2Vec2ForCTC` instead.", FutureWarning + ) + self.wav2vec2 = Wav2Vec2Model(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) @@ -729,3 +734,77 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): return output return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + +@add_start_docstrings( + """Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """, + WAV_2_VEC_2_START_DOCSTRING, +) +class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wav2vec2 = Wav2Vec2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + + self.init_weights() + + @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + TODO(PVP): Fill out when adding training + + Returns: + + Example:: + + >>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2Model + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") + >>> model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") + + >>> def map_to_array(batch): + >>> speech, _ = sf.read(batch["file"]) + >>> batch["speech"] = speech + >>> return batch + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values # Batch size 1 + >>> logits = model(input_values).logits + + >>> predicted_ids = torch.argmax(logits, dim=-1) + >>> transcription = tokenizer.decode(predicted_ids[0]) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.wav2vec2( + input_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.lm_head(hidden_states) + + if not return_dict: + output = (logits,) + outputs[1:] + return output + + return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b246074fc..63dede9b2 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2229,6 +2229,11 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs): WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None +class Wav2Vec2ForCTC: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class Wav2Vec2ForMaskedLM: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index 8f504959f..b9e726633 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init if is_torch_available(): import torch - from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer + from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer class Wav2Vec2ModelTester: @@ -204,7 +204,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): @require_torch class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else () + all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else () test_pruning = False test_headmasking = False test_torchscript = False @@ -289,7 +289,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): return ds["speech"][:num_samples] def test_inference_masked_lm_normal(self): - model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h") + model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") model.to(torch_device) tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True) @@ -307,7 +307,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) def test_inference_masked_lm_normal_batched(self): - model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h") + model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") model.to(torch_device) tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True) @@ -330,7 +330,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) def test_inference_masked_lm_robust_batched(self): - model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device) + model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device) tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True) input_speech = self._load_datasamples(4) diff --git a/utils/check_repo.py b/utils/check_repo.py index d0ca63ef1..f9c25dabc 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -118,6 +118,7 @@ IGNORE_NON_AUTO_CONFIGURED = [ "TFMT5EncoderModel", "TFOpenAIGPTDoubleHeadsModel", "TFT5EncoderModel", + "Wav2Vec2ForCTC", "XLMForQuestionAnswering", "XLMProphetNetDecoder", "XLMProphetNetEncoder", @@ -370,6 +371,7 @@ DEPRECATED_OBJECTS = [ "TFBartPretrainedModel", "TextDataset", "TextDatasetForNextSentencePrediction", + "Wav2Vec2ForMaskedLM", "glue_compute_metrics", "glue_convert_examples_to_features", "glue_output_modes",