From 56b03c96b865a40811f4eb2942e71aaab4cd38c2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 13 Feb 2023 21:23:00 +0000 Subject: [PATCH] Fix TF CTC tests (#21606) --- .../models/hubert/test_modeling_tf_hubert.py | 26 ++++++++++++++----- .../wav2vec2/test_modeling_tf_wav2vec2.py | 4 +-- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/tests/models/hubert/test_modeling_tf_hubert.py b/tests/models/hubert/test_modeling_tf_hubert.py index 3cd810118..b20119c64 100644 --- a/tests/models/hubert/test_modeling_tf_hubert.py +++ b/tests/models/hubert/test_modeling_tf_hubert.py @@ -321,6 +321,20 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase): model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960") self.assertIsNotNone(model) + # We override here as passing a full batch of 13 samples results in OOM errors for CTC + def test_dataset_conversion(self): + default_batch_size = self.model_tester.batch_size + self.model_tester.batch_size = 2 + super().test_dataset_conversion() + self.model_tester.batch_size = default_batch_size + + # We override here as passing a full batch of 13 samples results in OOM errors for CTC + def test_keras_fit(self): + default_batch_size = self.model_tester.batch_size + self.model_tester.batch_size = 2 + super().test_keras_fit() + self.model_tester.batch_size = default_batch_size + @require_tf class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase): @@ -431,20 +445,18 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase): def test_model_common_attributes(self): pass + @slow + def test_model_from_pretrained(self): + model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft") + self.assertIsNotNone(model) + # We override here as passing a full batch of 13 samples results in OOM errors for CTC - # TODO: fix me - @unittest.skip(reason="Crashing on CI, temporarily skipped") def test_dataset_conversion(self): default_batch_size = self.model_tester.batch_size self.model_tester.batch_size = 2 super().test_dataset_conversion() self.model_tester.batch_size = default_batch_size - @slow - def test_model_from_pretrained(self): - model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft") - self.assertIsNotNone(model) - # We override here as passing a full batch of 13 samples results in OOM errors for CTC def test_keras_fit(self): default_batch_size = self.model_tester.batch_size diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index 2e3c2c26c..d8e3f52a0 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -396,7 +396,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase): def test_keras_fit(self): default_batch_size = self.model_tester.batch_size self.model_tester.batch_size = 2 - super().test_dataset_conversion() + super().test_keras_fit() self.model_tester.batch_size = default_batch_size @@ -527,7 +527,7 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase): def test_keras_fit(self): default_batch_size = self.model_tester.batch_size self.model_tester.batch_size = 2 - super().test_dataset_conversion() + super().test_keras_fit() self.model_tester.batch_size = default_batch_size