This commit is contained in:
ydshieh 2024-12-06 06:25:03 +01:00
parent 723b9e2d3a
commit e7af32f234

View file

@ -608,9 +608,35 @@ class TFModelTesterMixin:
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs_dict)
tf_outputs = tf_model(tf_inputs_dict)
from contextlib import contextmanager
from transformers.testing_utils import set_model_for_less_flaky_test, set_model_tester_for_less_flaky_test, set_config_for_less_flaky_test
def foo1(func):
    """Return a wrapper around ``func`` that always calls it with ``eps=1.0``.

    Any ``eps`` keyword supplied by the caller is overwritten. All other
    positional and keyword arguments are forwarded unchanged, and the wrapped
    function's return value is returned as-is.

    NOTE(review): presumably the large ``eps`` is meant to make normalization
    less sensitive to tiny norms in the PT/TF comparison -- confirm.
    NOTE(review): if a caller passes ``eps`` positionally, the forced keyword
    will clash with it (``TypeError``: multiple values for ``eps``).

    Args:
        func: The callable to wrap; it must accept an ``eps`` keyword argument.

    Returns:
        A callable with the same metadata (name, docstring) as ``func``.
    """
    import functools

    # Preserve the wrapped function's name/docstring so the patched callable
    # stays recognizable in tracebacks and debugging output.
    @functools.wraps(func)
    def wrap(*args, **kwargs):
        # Unconditionally override whatever epsilon the caller asked for.
        kwargs["eps"] = 1.0
        return func(*args, **kwargs)

    return wrap
def foo2(func):
    """Return a wrapper around ``func`` that always calls it with ``epsilon=1.0``.

    Any ``epsilon`` keyword supplied by the caller is overwritten. All other
    positional and keyword arguments are forwarded unchanged, and the wrapped
    function's return value is returned as-is.

    NOTE(review): presumably the large ``epsilon`` is meant to make
    normalization less sensitive to tiny norms in the PT/TF comparison -- confirm.
    NOTE(review): if a caller passes ``epsilon`` positionally, the forced
    keyword will clash with it (``TypeError``: multiple values for ``epsilon``).

    Args:
        func: The callable to wrap; it must accept an ``epsilon`` keyword argument.

    Returns:
        A callable with the same metadata (name, docstring) as ``func``.
    """
    import functools

    # Preserve the wrapped function's name/docstring so the patched callable
    # stays recognizable in tracebacks and debugging output.
    @functools.wraps(func)
    def wrap(*args, **kwargs):
        # Unconditionally override whatever epsilon the caller asked for.
        kwargs["epsilon"] = 1.0
        return func(*args, **kwargs)

    return wrap
set_model_for_less_flaky_test(pt_model)
set_model_for_less_flaky_test(tf_model)
import unittest
@contextmanager
def patched_norm_layers():
    """Temporarily patch the PyTorch and TensorFlow L2-normalization functions.

    While the context is active:

    * ``torch.nn.functional.normalize`` is replaced by a mock whose
      ``side_effect`` forwards to the real function with ``eps=1.0`` (``foo1``).
    * ``tf.math.l2_normalize`` is replaced by a mock whose ``side_effect``
      forwards to the real function with ``epsilon=1.0`` (``foo2``).

    Both attributes are restored when the context exits, including when the
    body raises.
    """
    import torch

    # `import unittest` alone does not load the `unittest.mock` submodule, so
    # `unittest.mock.patch` only works if something else already imported it.
    # Import it explicitly to avoid an import-order-dependent AttributeError.
    from unittest import mock

    torch_patch = mock.patch.object(
        torch.nn.functional, "normalize", side_effect=foo1(torch.nn.functional.normalize)
    )
    tf_patch = mock.patch.object(tf.math, "l2_normalize", side_effect=foo2(tf.math.l2_normalize))
    with torch_patch, tf_patch:
        yield
with patched_norm_layers():
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs_dict)
tf_outputs = tf_model(tf_inputs_dict)
# The loss returned by TF models is usually a tensor rather than a scalar.
# (see `hf_compute_loss`: it uses `keras.losses.Reduction.NONE`)
@ -624,9 +650,11 @@ class TFModelTesterMixin:
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
import transformers
from transformers.testing_utils import set_model_for_less_flaky_test, set_model_tester_for_less_flaky_test, set_config_for_less_flaky_test
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
set_config_for_less_flaky_test(config)
# Output all for aggressive testing
config.output_hidden_states = True