mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fix inverted conditional in TF common test! (#22540)
* Fix inverted conditional in TF common test! * Make the same change in the PT tests file * Make sure hidden states for GPT2 have the same output shape in PT/TF * Minor fix to PT implementation of token classification loss * Skip loss equivalence test for TFHubert because it keeps overflowing to inf * Compute LM loss for TF the (weird) way it's computed in PT * Skip loss equivalence test for Wav2Vec2 for the same reason as Hubert * Fix - don't try to access the hidden states property when output is a tuple
This commit is contained in:
parent
48fbd8fa2e
commit
edb704b26e
7 changed files with 245 additions and 17 deletions
|
|
@ -1228,16 +1228,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
|||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
# Only keep active parts of the loss
|
||||
if attention_mask is not None:
|
||||
active_loss = attention_mask.view(-1) == 1
|
||||
active_logits = logits.view(-1, self.num_labels)
|
||||
active_labels = torch.where(
|
||||
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
||||
)
|
||||
loss = loss_fct(active_logits, active_labels)
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
|
|
|||
|
|
@ -1051,6 +1051,12 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
||||
if return_dict and output_hidden_states:
|
||||
# We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
|
||||
# input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
|
||||
all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
|
||||
else:
|
||||
all_hidden_states = None
|
||||
lm_logits = self.transformer.wte(hidden_states, mode="linear")
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
||||
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
||||
|
|
@ -1062,7 +1068,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||
logits=lm_logits,
|
||||
mc_logits=mc_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -953,9 +953,11 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||
loss = None
|
||||
if labels is not None:
|
||||
# shift labels to the left and cut last logit token
|
||||
shifted_logits = lm_logits[:, :-1]
|
||||
labels = labels[:, 1:]
|
||||
loss = self.hf_compute_loss(labels, shifted_logits)
|
||||
labels = tf.concat(
|
||||
[labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(self.config.pad_token_id, labels.dtype))],
|
||||
axis=-1,
|
||||
)
|
||||
loss = self.hf_compute_loss(labels, lm_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
|
|
|
|||
|
|
@ -17,13 +17,15 @@
|
|||
import copy
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_soundfile, require_tf, slow
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_soundfile, require_tf, slow
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
|
|
@ -333,6 +335,62 @@ class TFHubertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
|
||||
pass
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
# We override the base test here to skip loss calculation for Hubert models because the loss is massive with
|
||||
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
|
||||
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
|
||||
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
|
||||
self._make_attention_mask_non_null(inputs_dict)
|
||||
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(
|
||||
pt_model, tf_model, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
|
||||
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
|
||||
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
|
@ -458,6 +516,62 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
|
||||
pass
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
# We override the base test here to skip loss calculation for Hubert models because the loss is massive with
|
||||
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
|
||||
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
|
||||
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
|
||||
self._make_attention_mask_non_null(inputs_dict)
|
||||
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(
|
||||
pt_model, tf_model, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
|
||||
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
|
||||
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFHubertUtilsTest(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ import glob
|
|||
import inspect
|
||||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
|
|
@ -31,6 +33,7 @@ from transformers import Wav2Vec2Config, is_tf_available
|
|||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
is_pt_tf_cross_test,
|
||||
require_librosa,
|
||||
require_pyctcdecode,
|
||||
require_tf,
|
||||
|
|
@ -397,6 +400,62 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
|
|||
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
|
||||
pass
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
# We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
|
||||
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
|
||||
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
|
||||
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
|
||||
self._make_attention_mask_non_null(inputs_dict)
|
||||
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(
|
||||
pt_model, tf_model, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
|
||||
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
|
||||
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
|
@ -524,6 +583,62 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
|
||||
pass
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
# We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
|
||||
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
|
||||
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
|
||||
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
|
||||
self._make_attention_mask_non_null(inputs_dict)
|
||||
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(
|
||||
pt_model, tf_model, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
|
||||
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
|
||||
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2UtilsTest(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -2030,7 +2030,7 @@ class ModelTesterMixin:
|
|||
|
||||
# For some models (e.g. base models), there is no label returned.
|
||||
# Set the input dict to `None` to avoid check outputs twice for the same input dicts.
|
||||
if set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
|
||||
if not set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
|
||||
pt_inputs_dict_with_labels = None
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
|
|
|
|||
|
|
@ -699,7 +699,7 @@ class TFModelTesterMixin:
|
|||
|
||||
# For some models (e.g. base models), there is no label returned.
|
||||
# Set the input dict to `None` to avoid check outputs twice for the same input dicts.
|
||||
if set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
|
||||
if not set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
|
||||
tf_inputs_dict_with_labels = None
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
|
|
|
|||
Loading…
Reference in a new issue