diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 9f6780800..e1676399c 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -960,7 +960,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
             # take argmax in non-differentiable way
             # comptute hard codevector distribution (one hot)
             codevector_idx = hidden_states.argmax(dim=-1)
-            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+            codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
                 -1, codevector_idx.view(-1, 1), 1.0
             )
             codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index 5bee0d040..8723c6338 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -1023,7 +1023,7 @@ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
             # take argmax in non-differentiable way
             # comptute hard codevector distribution (one hot)
             codevector_idx = hidden_states.argmax(dim=-1)
-            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+            codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
                 -1, codevector_idx.view(-1, 1), 1.0
             )
             codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index c08f6766c..d3255baf8 100644
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -104,6 +104,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "blenderbot-small",
     "bloom",
     "clip",
+    "convnext",
     "deberta",
     "deberta-v2",
     "distilbert",
@@ -125,6 +126,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "opt",
     "pegasus",
     "plbart",
+    "resnet",
     "roberta",
     "speech_to_text",
     "speech_to_text_2",
@@ -133,6 +135,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "trocr",
     "vit",
+    "wav2vec2",
     "xglm",
     # "xlnet",
 ]

@@ -743,7 +746,7 @@ class HFTracer(Tracer):
         elif hasattr(model.config, "encoder"):
             image_size = model.config.encoder.image_size
         else:
-            raise AttributeError('Could not find the "image_size" field in the model config')
+            image_size = (_generate_random_int(), _generate_random_int())

         # If no num_channels is in the config, use some arbitrary value.
         num_channels = getattr(model.config, "num_channels", 3)
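The new_zeros change above is what makes the quantizer FX-traceable: unpacking a tensor's shape with `*` calls iter() on an fx.Proxy at trace time, which torch.fx rejects, while passing the shape object through whole keeps it symbolic. A minimal sketch of the failure mode, assuming vanilla torch.fx (transformers' HFTracer hits the analogous constraint):

    import torch
    from torch import fx, nn


    class Unpacks(nn.Module):
        def forward(self, x):
            # `*x.shape` iterates over an fx.Proxy at trace time -> TraceError
            return x.new_zeros(*x.shape)


    class PassesWhole(nn.Module):
        def forward(self, x):
            # The shape Proxy is passed as a single argument -> traceable
            return x.new_zeros(x.shape)


    fx.symbolic_trace(PassesWhole())  # succeeds
    try:
        fx.symbolic_trace(Unpacks())
    except Exception as e:  # torch.fx.proxy.TraceError
        print(type(e).__name__, e)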
diff --git a/tests/models/convnext/test_modeling_convnext.py b/tests/models/convnext/test_modeling_convnext.py
index 46ef3ce71..1225175a1 100644
--- a/tests/models/convnext/test_modeling_convnext.py
+++ b/tests/models/convnext/test_modeling_convnext.py
@@ -137,6 +137,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )

+    fx_compatible = True
     test_pruning = False
     test_resize_embeddings = False
     test_head_masking = False
diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py
index 83f08b68a..557883e0b 100644
--- a/tests/models/resnet/test_modeling_resnet.py
+++ b/tests/models/resnet/test_modeling_resnet.py
@@ -126,6 +126,7 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else ()

+    fx_compatible = True
     test_pruning = False
     test_resize_embeddings = False
     test_head_masking = False
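For context, a hedged sketch of what flipping fx_compatible = True exercises for the vision models. ResNetConfig defines no image_size field, so tracing it relies on the random image-size fallback added to fx.py above; the config values here are illustrative:

    import torch
    from transformers import ResNetConfig, ResNetForImageClassification
    from transformers.utils.fx import symbolic_trace

    config = ResNetConfig(num_labels=10)  # illustrative small classification head
    config.return_dict = False  # the tests compare plain tuple outputs
    model = ResNetForImageClassification(config).eval()

    traced = symbolic_trace(model, input_names=["pixel_values"])
    logits = traced(pixel_values=torch.randn(1, config.num_channels, 224, 224))[0]
    print(logits.shape)  # torch.Size([1, 10])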
diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py
index 21f77b19a..040731472 100644
--- a/tests/models/wav2vec2/test_modeling_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py
@@ -15,6 +15,9 @@
 """ Testing suite for the PyTorch Wav2Vec2 model. """

 import math
+import os
+import pickle
+import tempfile
 import unittest

 import numpy as np
@@ -32,6 +35,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils import is_torch_fx_available

 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
@@ -72,6 +76,10 @@ if is_pyctcdecode_available():
     from transformers import Wav2Vec2ProcessorWithLM


+if is_torch_fx_available():
+    from transformers.utils.fx import symbolic_trace
+
+
 class Wav2Vec2ModelTester:
     def __init__(
         self,
@@ -411,6 +419,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    fx_compatible = True
     test_pruning = False
     test_headmasking = False

@@ -633,6 +642,104 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
         model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
         self.assertIsNotNone(model)

+    # Wav2Vec2 cannot be torchscripted because of group norm, hence this override of the common test.
+    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+        if not is_torch_fx_available() or not self.fx_compatible:
+            return
+
+        configs_no_init = _config_zero_init(config)  # To be sure we have no NaN
+        configs_no_init.return_dict = False
+
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            model.to(torch_device)
+            model.eval()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+            try:
+                input_names = [
+                    "attention_mask",
+                    "bbox",
+                    "input_features",
+                    "input_ids",
+                    "input_values",
+                    "pixel_values",
+                    "token_type_ids",
+                    "visual_feats",
+                    "visual_pos",
+                ]
+
+                labels = inputs.get("labels", None)
+                start_positions = inputs.get("start_positions", None)
+                end_positions = inputs.get("end_positions", None)
+                if labels is not None:
+                    input_names.append("labels")
+                if start_positions is not None:
+                    input_names.append("start_positions")
+                if end_positions is not None:
+                    input_names.append("end_positions")
+
+                filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+                input_names = list(filtered_inputs.keys())
+
+                model_output = model(**filtered_inputs)
+
+                if isinstance(model, Wav2Vec2ForSequenceClassification) and (
+                    not hasattr(model.config, "problem_type") or model.config.problem_type is None
+                ):
+                    model.config.problem_type = "single_label_classification"
+
+                traced_model = symbolic_trace(model, input_names)
+                traced_output = traced_model(**filtered_inputs)
+
+            except Exception as e:
+                self.fail(f"Couldn't trace module: {e}")
+
+            def flatten_output(output):
+                flatten = []
+                for x in output:
+                    if isinstance(x, (tuple, list)):
+                        flatten += flatten_output(x)
+                    elif not isinstance(x, torch.Tensor):
+                        continue
+                    else:
+                        flatten.append(x)
+                return flatten
+
+            model_output = flatten_output(model_output)
+            traced_output = flatten_output(traced_output)
+            num_outputs = len(model_output)
+
+            for i in range(num_outputs):
+                self.assertTrue(
+                    torch.allclose(model_output[i], traced_output[i]),
+                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
+                )
+
+            # Test that the model can be serialized and restored properly
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+                try:
+                    with open(pkl_file_name, "wb") as f:
+                        pickle.dump(traced_model, f)
+                    with open(pkl_file_name, "rb") as f:
+                        loaded = pickle.load(f)
+                except Exception as e:
+                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+                loaded_output = loaded(**filtered_inputs)
+                loaded_output = flatten_output(loaded_output)
+
+                for i in range(num_outputs):
+                    self.assertTrue(
+                        torch.allclose(model_output[i], loaded_output[i]),
+                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+                    )
+
+            # Avoid a memory leak. Without this, each call increases RAM usage by ~20MB.
+            # (Even with this call, there is still a leak of ~0.04MB per call.)
+            self.clear_torch_jit_class_registry()
+


 @require_torch
 class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
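Finally, a hedged end-to-end sketch of the behavior the new Wav2Vec2 test locks in: symbolic tracing plus a pickle round-trip of the traced GraphModule. A randomly initialized model is used here for brevity; a pretrained checkpoint should behave the same way:

    import pickle

    import torch
    from transformers import Wav2Vec2Config, Wav2Vec2ForCTC
    from transformers.utils.fx import symbolic_trace

    config = Wav2Vec2Config()
    config.return_dict = False  # mirror the test, which compares tuple outputs
    model = Wav2Vec2ForCTC(config).eval()

    traced = symbolic_trace(model, input_names=["input_values"])
    restored = pickle.loads(pickle.dumps(traced))  # serialize / deserialize

    dummy = torch.randn(1, 16000)  # one second of 16 kHz mono audio
    with torch.no_grad():
        eager_logits = model(input_values=dummy)[0]
        traced_logits = restored(input_values=dummy)[0]
    print(torch.allclose(eager_logits, traced_logits))  # True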