From fc95386ea12fc11942cc7f2a4f99ef9602d774ef Mon Sep 17 00:00:00 2001 From: Cole Howard Date: Wed, 7 Dec 2022 09:05:39 -0800 Subject: [PATCH] Add TFBartForSequenceClassification (#20570) * read to load * base functionality * revert init * fix dummy data * moving right along * moving right along * finally * cleanup * pull out comment * add test * update docstring for main class * flake comments and rewriting copies from make repo-consistency` * remove irrelevant differences/accidental spaces * put copies back after space removals * mid * final test pass * stray comment * update test file * update test file * fixup * black * missed * black missed one more * sytle * add doc update * fix order of output class * comment * Revert "comment" This reverts commit 03f86b6948808461939cc8ad4ad74305dfb67700. * remove redundant function, and redundant reshape * move change out of common * style * put common spaces back * reorder kwargs in output * doc style --- docs/source/en/model_doc/bart.mdx | 5 + src/transformers/__init__.py | 11 +- .../convert_pytorch_checkpoint_to_tf2.py | 2 + src/transformers/modeling_tf_outputs.py | 4 + src/transformers/modeling_tf_utils.py | 2 +- .../models/auto/modeling_tf_auto.py | 1 + src/transformers/models/bart/__init__.py | 14 +- .../models/bart/modeling_tf_bart.py | 159 +++++++++++++++++- src/transformers/utils/dummy_tf_objects.py | 7 + tests/models/bart/test_modeling_tf_bart.py | 142 +++++++++++++++- 10 files changed, 338 insertions(+), 9 deletions(-) diff --git a/docs/source/en/model_doc/bart.mdx b/docs/source/en/model_doc/bart.mdx index e2d788c8c..9fae10212 100644 --- a/docs/source/en/model_doc/bart.mdx +++ b/docs/source/en/model_doc/bart.mdx @@ -157,6 +157,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] TFBartForConditionalGeneration - call +## TFBartForSequenceClassification + +[[autodoc]] TFBartForSequenceClassification + - call + ## FlaxBartModel [[autodoc]] FlaxBartModel diff --git 
a/src/transformers/__init__.py b/src/transformers/__init__.py index a0bb4f9dd..dce1d73e9 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2513,7 +2513,9 @@ else: "TFAutoModelWithLMHead", ] ) - _import_structure["models.bart"].extend(["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"]) + _import_structure["models.bart"].extend( + ["TFBartForConditionalGeneration", "TFBartForSequenceClassification", "TFBartModel", "TFBartPretrainedModel"] + ) _import_structure["models.bert"].extend( [ "TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -5402,7 +5404,12 @@ if TYPE_CHECKING: TFAutoModelForVision2Seq, TFAutoModelWithLMHead, ) - from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel + from .models.bart import ( + TFBartForConditionalGeneration, + TFBartForSequenceClassification, + TFBartModel, + TFBartPretrainedModel, + ) from .models.bert import ( TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFBertEmbeddings, diff --git a/src/transformers/convert_pytorch_checkpoint_to_tf2.py b/src/transformers/convert_pytorch_checkpoint_to_tf2.py index 6a05e40f0..62a071dd3 100755 --- a/src/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/src/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -58,6 +58,7 @@ from . 
import ( T5Config, TFAlbertForPreTraining, TFBartForConditionalGeneration, + TFBartForSequenceClassification, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, @@ -136,6 +137,7 @@ MODEL_CLASSES = { "bart": ( BartConfig, TFBartForConditionalGeneration, + TFBartForSequenceClassification, BartForConditionalGeneration, BART_PRETRAINED_MODEL_ARCHIVE_LIST, ), diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index efb241208..0fed3e785 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -623,6 +623,9 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput): Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)` encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder of the model. 
encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): @@ -643,6 +646,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput): past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None + cross_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None encoder_attentions: Optional[Tuple[tf.Tensor]] = None diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index a642e883b..5e3c49290 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1190,7 +1190,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu return self.serving_output(output) - def serving_output(output): + def serving_output(self, output): """ Prepare the output of the saved model. Each model must implement this function. 
diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 8bb7b5595..63934d67c 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -268,6 +268,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping ("albert", "TFAlbertForSequenceClassification"), + ("bart", "TFBartForSequenceClassification"), ("bert", "TFBertForSequenceClassification"), ("camembert", "TFCamembertForSequenceClassification"), ("convbert", "TFConvBertForSequenceClassification"), diff --git a/src/transformers/models/bart/__init__.py b/src/transformers/models/bart/__init__.py index ec1010f7b..99ce16525 100644 --- a/src/transformers/models/bart/__init__.py +++ b/src/transformers/models/bart/__init__.py @@ -63,7 +63,12 @@ try: except OptionalDependencyNotAvailable: pass else: - _import_structure["modeling_tf_bart"] = ["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"] + _import_structure["modeling_tf_bart"] = [ + "TFBartForConditionalGeneration", + "TFBartForSequenceClassification", + "TFBartModel", + "TFBartPretrainedModel", + ] try: if not is_flax_available(): @@ -116,7 +121,12 @@ if TYPE_CHECKING: except OptionalDependencyNotAvailable: pass else: - from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel + from .modeling_tf_bart import ( + TFBartForConditionalGeneration, + TFBartForSequenceClassification, + TFBartModel, + TFBartPretrainedModel, + ) try: if not is_flax_available(): diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 7dbe3384e..19204f7e8 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -27,6 +27,7 @@ from ...modeling_tf_outputs import ( TFBaseModelOutputWithPastAndCrossAttentions, TFSeq2SeqLMOutput, 
TFSeq2SeqModelOutput, + TFSeq2SeqSequenceClassifierOutput, ) # Public API @@ -35,6 +36,7 @@ from ...modeling_tf_utils import ( TFCausalLanguageModelingLoss, TFModelInputType, TFPreTrainedModel, + TFSequenceClassificationLoss, keras_serializable, unpack_inputs, ) @@ -460,6 +462,24 @@ class TFBartDecoderLayer(tf.keras.layers.Layer): ) +class TFBartClassificationHead(tf.keras.layers.Layer): + """Head for sentence-level classification tasks.""" + + def __init__(self, inner_dim: int, num_classes: int, pooler_dropout: float, name: str, **kwargs): + super().__init__(name=name, **kwargs) + self.dense = tf.keras.layers.Dense(inner_dim, name="dense") + self.dropout = tf.keras.layers.Dropout(pooler_dropout) + self.out_proj = tf.keras.layers.Dense(num_classes, name="out_proj") + + def call(self, inputs): + hidden_states = self.dropout(inputs) + hidden_states = self.dense(hidden_states) + hidden_states = tf.keras.activations.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + class TFBartPretrainedModel(TFPreTrainedModel): config_class = BartConfig base_model_prefix = "model" @@ -726,7 +746,6 @@ class TFBartEncoder(tf.keras.layers.Layer): return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -1465,3 +1484,141 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past + + +@add_start_docstrings( + """ + Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. 
+ """, + BART_START_DOCSTRING, +) +class TFBartForSequenceClassification(TFBartPretrainedModel, TFSequenceClassificationLoss): + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = tf.constant([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]]) + dummy_inputs = { + "attention_mask": tf.cast(tf.math.not_equal(input_ids, (pad_token)), dtype=tf.int32), + "input_ids": input_ids, + } + return dummy_inputs + + def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model") + self.classification_head = TFBartClassificationHead( + config.d_model, config.num_labels, config.classifier_dropout, name="classification_head" + ) + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + @unpack_inputs + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_outputs: Optional[TFBaseModelOutput] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: 
Optional[tf.Tensor] = None, + training: Optional[bool] = False, + ) -> Union[TFSeq2SeqSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + last_hidden_state = outputs[0] + eos_mask = tf.equal(input_ids, self.config.eos_token_id) + # out the rows with False where present.
Then verify all the final + # entries are True + self_masked = tf.reshape(tf.boolean_mask(eos_mask, eos_mask), (tf.shape(input_ids)[0], -1)) + tf.Assert(tf.reduce_all(self_masked[:, -1]), ["All examples must have the same number of tokens."]) + + masked = tf.reshape( + tf.boolean_mask(last_hidden_state, eos_mask), + (tf.shape(input_ids)[0], tf.shape(self_masked)[1], tf.shape(last_hidden_state)[-1]), + ) + + sentence_representation = masked[:, -1, :] + logits = self.classification_head(sentence_representation) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFSeq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def serving_output(self, output): + logits = tf.convert_to_tensor(output.logits) + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqSequenceClassifierOutput( + logits=logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + 
decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index d16a75591..a24e53c50 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -449,6 +449,13 @@ class TFBartForConditionalGeneration(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFBartForSequenceClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFBartModel(metaclass=DummyObject): _backends = ["tf"] diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py index e4c4f43c4..1b3682a76 100644 --- a/tests/models/bart/test_modeling_tf_bart.py +++ b/tests/models/bart/test_modeling_tf_bart.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy +import tempfile import unittest import numpy as np @@ -29,7 +31,7 @@ from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin if is_tf_available(): import tensorflow as tf - from transformers import TFBartForConditionalGeneration, TFBartModel + from transformers import TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel @require_tf @@ -76,7 +78,13 @@ class TFBartModelTester: self.bos_token_id = bos_token_id def prepare_config_and_inputs_for_common(self): - input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size) + # Ids are clipped to avoid "beginning of sequence", "end of sequence", and "pad" tokens + input_ids = tf.clip_by_value( + ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), + clip_value_min=self.eos_token_id + 1, + clip_value_max=self.vocab_size + 1, + ) + # Explicitly add "end of sequence" to the inputs eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1) input_ids = tf.concat([input_ids, eos_tensor], axis=1) @@ -181,7 +189,9 @@ def prepare_bart_inputs_dict( @require_tf class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase): - all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else () + all_model_classes = ( + (TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel) if is_tf_available() else () + ) all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False @@ -228,6 +238,119 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC def test_onnx_compliancy(self): pass + # TFBartForSequenceClassification does not support inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (TFBartForConditionalGeneration, TFBartModel): + model = model_class(config) + +
inputs = copy.deepcopy(inputs_dict) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids) + else: + inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids) + inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids) + + inputs = self._prepare_for_class(inputs, model_class) + + model(inputs) + + # TFBartForSequenceClassification does not support inputs_embeds + @slow + def test_graph_mode_with_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (TFBartForConditionalGeneration, TFBartModel): + model = model_class(config) + + inputs = copy.deepcopy(inputs_dict) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids) + else: + inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids) + inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids) + + inputs = self._prepare_for_class(inputs, model_class) + + @tf.function + def run_in_graph_mode(): + return model(inputs) + + outputs = run_in_graph_mode() + self.assertIsNotNone(outputs) + + @slow + def test_save_load_after_resize_token_embeddings(self): + # Custom version of this test to ensure "end of sequence" tokens are present throughout + if not self.test_resize_embeddings: + return + config, original_inputs_dict 
= self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + # create a model with resized (expanded) embeddings + new_tokens_size = 10 + old_total_size = config.vocab_size + new_total_size = old_total_size + new_tokens_size + model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config` + model(model.dummy_inputs) # builds the embeddings layer + model.resize_token_embeddings(new_total_size) + + # fetch the output for an input exclusively made of new members of the vocabulary + inputs_dict = copy.deepcopy(original_inputs_dict) + ids_feat_name = None + if "input_ids" in inputs_dict: + ids_feat_name = "input_ids" + elif "decoder_input_ids" in inputs_dict: + ids_feat_name = "decoder_input_ids" + else: + assert False, "No input ids feature found in the inputs dict" + + new_vocab_input_ids = ids_tensor(inputs_dict[ids_feat_name].shape, new_tokens_size) + new_vocab_input_ids += old_total_size + + # Replace last id with EOS token + new_vocab_input_ids = new_vocab_input_ids[:, :-1] + new_vocab_input_ids = tf.concat( + [new_vocab_input_ids, tf.ones((tf.shape(new_vocab_input_ids)[0], 1), dtype=tf.int32) * 2], axis=1 + ) + + inputs_dict[ids_feat_name] = new_vocab_input_ids + if "input_ids" in inputs_dict: + inputs_dict["input_ids"] = new_vocab_input_ids + if "decoder_input_ids" in inputs_dict: + inputs_dict["decoder_input_ids"] = new_vocab_input_ids + prepared_inputs = self._prepare_for_class(inputs_dict, model_class) + outputs = model(**prepared_inputs) + + # save and load the model + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, saved_model=False) + model = model_class.from_pretrained(tmpdirname) + restored_model_outputs = model(**prepared_inputs) + + # check that the output for the restored model is the same + self.assert_outputs_same(restored_model_outputs, outputs) + def _long_tensor(tok_lst): return tf.constant(tok_lst, dtype=tf.int32) @@ -286,6 +409,19
@@ class TFBartHeadTests(unittest.TestCase): self.assertEqual(outputs.logits.shape, expected_shape) +@require_tf +class TFBartForSequenceClassificationTest(unittest.TestCase): + def test_model_fails_for_uneven_eos_tokens(self): + config = BartConfig(eos_token_id=2) + model = TFBartForSequenceClassification(config) + inputs = { + "input_ids": tf.constant([[1, 2, 2, 2], [1, 3, 2, 2], [2, 2, 3, 3]]), + "attention_mask": tf.constant([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]), + } + with self.assertRaises(tf.errors.InvalidArgumentError): + model(inputs) + + @slow @require_tf class TFBartModelIntegrationTest(unittest.TestCase):