diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index eb328d83e..d5b16a827 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -608,9 +608,35 @@ class TFModelTesterMixin: # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences pt_model.eval() - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs_dict) - tf_outputs = tf_model(tf_inputs_dict) + from contextlib import contextmanager + from transformers.testing_utils import set_model_for_less_flaky_test, set_model_tester_for_less_flaky_test, set_config_for_less_flaky_test + def foo1(func): + def wrap(*args, **kwargs): + kwargs["eps"] = 1.0 + return func(*args, **kwargs) + return wrap + + def foo2(func): + def wrap(*args, **kwargs): + kwargs["epsilon"] = 1.0 + return func(*args, **kwargs) + return wrap + + set_model_for_less_flaky_test(pt_model) + set_model_for_less_flaky_test(tf_model) + + import unittest + @contextmanager + def patched_norm_layers(): + import torch + with unittest.mock.patch.object(torch.nn.functional, "normalize", side_effect=foo1(torch.nn.functional.normalize)): + with unittest.mock.patch.object(tf.math, "l2_normalize", side_effect=foo2(tf.math.l2_normalize)): + yield + + with patched_norm_layers(): + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs_dict) + tf_outputs = tf_model(tf_inputs_dict) # tf models returned loss is usually a tensor rather than a scalar. # (see `hf_compute_loss`: it uses `keras.losses.Reduction.NONE`) @@ -624,9 +650,11 @@ class TFModelTesterMixin: @is_pt_tf_cross_test def test_pt_tf_model_equivalence(self, allow_missing_keys=False): import transformers + from transformers.testing_utils import set_model_for_less_flaky_test, set_model_tester_for_less_flaky_test, set_config_for_less_flaky_test for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + set_config_for_less_flaky_test(config) # Output all for aggressive testing config.output_hidden_states = True