mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix
This commit is contained in:
parent
723b9e2d3a
commit
e7af32f234
1 changed files with 31 additions and 3 deletions
|
|
@ -608,9 +608,35 @@ class TFModelTesterMixin:
|
|||
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
|
||||
pt_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs_dict)
|
||||
tf_outputs = tf_model(tf_inputs_dict)
|
||||
from contextlib import contextmanager
|
||||
from transformers.testing_utils import set_model_for_less_flaky_test, set_model_tester_for_less_flaky_test, set_config_for_less_flaky_test
|
||||
def foo1(func):
|
||||
def wrap(*args, **kwargs):
|
||||
kwargs["eps"] = 1.0
|
||||
return func(*args, **kwargs)
|
||||
return wrap
|
||||
|
||||
def foo2(func):
|
||||
def wrap(*args, **kwargs):
|
||||
kwargs["epsilon"] = 1.0
|
||||
return func(*args, **kwargs)
|
||||
return wrap
|
||||
|
||||
set_model_for_less_flaky_test(pt_model)
|
||||
set_model_for_less_flaky_test(tf_model)
|
||||
|
||||
import unittest
|
||||
@contextmanager
|
||||
def patched_norm_layers():
|
||||
import torch
|
||||
with unittest.mock.patch.object(torch.nn.functional, "normalize", side_effect=foo1(torch.nn.functional.normalize)):
|
||||
with unittest.mock.patch.object(tf.math, "l2_normalize", side_effect=foo2(tf.math.l2_normalize)):
|
||||
yield
|
||||
|
||||
with patched_norm_layers():
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs_dict)
|
||||
tf_outputs = tf_model(tf_inputs_dict)
|
||||
|
||||
# tf models returned loss is usually a tensor rather than a scalar.
|
||||
# (see `hf_compute_loss`: it uses `keras.losses.Reduction.NONE`)
|
||||
|
|
@ -624,9 +650,11 @@ class TFModelTesterMixin:
|
|||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
import transformers
|
||||
from transformers.testing_utils import set_model_for_less_flaky_test, set_model_tester_for_less_flaky_test, set_config_for_less_flaky_test
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
set_config_for_less_flaky_test(config)
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
|
|
|
|||
Loading…
Reference in a new issue