diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
index 7fa910a3e..b4bceee3e 100644
--- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -1287,7 +1287,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
index 612755882..95078af4b 100644
--- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
@@ -1265,7 +1265,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py
index d4bd3f53f..bc4921622 100644
--- a/src/transformers/models/flaubert/modeling_tf_flaubert.py
+++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py
@@ -182,8 +182,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
         mask = padding_mask
     else:
         # assert lengths.max().item() <= slen
-        alen = tf.range(slen)
-        mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
+        alen = tf.range(slen, dtype=lengths.dtype)
+        mask = alen < tf.expand_dims(lengths, axis=1)
 
     # attention mask is the same as mask, or triangular inferior attention (causal)
     if causal:
diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py
index b31ac1bd6..b7de8be6e 100644
--- a/src/transformers/models/mbart/modeling_tf_mbart.py
+++ b/src/transformers/models/mbart/modeling_tf_mbart.py
@@ -1300,7 +1300,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py
index be2539b3a..2a1b7994b 100644
--- a/src/transformers/models/pegasus/modeling_tf_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -1317,7 +1317,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
diff --git a/src/transformers/models/tapas/modeling_tf_tapas.py b/src/transformers/models/tapas/modeling_tf_tapas.py
index d2da06446..29cb63c3a 100644
--- a/src/transformers/models/tapas/modeling_tf_tapas.py
+++ b/src/transformers/models/tapas/modeling_tf_tapas.py
@@ -1726,7 +1726,10 @@ class ProductIndexMap(IndexMap):
             raise ValueError("outer_index.batch_dims and inner_index.batch_dims " "must be the same.")
 
         super(ProductIndexMap, self).__init__(
-            indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
+            indices=(
+                inner_index.indices
+                + outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype)
+            ),
             num_segments=inner_index.num_segments * outer_index.num_segments,
             batch_dims=inner_index.batch_dims,
         )
@@ -1785,7 +1788,7 @@ def flatten(index, name="segmented_flatten"):
     for _ in range(index.batch_dims, index.indices.shape.rank):
         offset = tf.expand_dims(offset, -1)
 
-    indices = offset + index.indices
+    indices = tf.cast(offset, index.indices.dtype) + index.indices
     return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0)
 
 
diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
index af95f348e..dcfa84d0f 100644
--- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
+++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
@@ -111,7 +111,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
     @staticmethod
     def _gather_logprob(logprob, target):
         lp_size = shape_list(logprob)
-        r = tf.range(lp_size[0])
+        r = tf.range(lp_size[0], dtype=target.dtype)
         idx = tf.stack([r, target], 1)
         return tf.gather_nd(logprob, idx)
 
diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py
index 24d32f798..fa3a54b6c 100644
--- a/src/transformers/models/xlm/modeling_tf_xlm.py
+++ b/src/transformers/models/xlm/modeling_tf_xlm.py
@@ -92,8 +92,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
         mask = padding_mask
     else:
         # assert lengths.max().item() <= slen
-        alen = tf.range(slen)
-        mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
+        alen = tf.range(slen, dtype=lengths.dtype)
+        mask = alen < tf.expand_dims(lengths, axis=1)
 
     # attention mask is the same as mask, or triangular inferior attention (causal)
     if causal:
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 0d38713e0..6edc6b20c 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -1372,6 +1372,26 @@ class TFModelTesterMixin:
         val_loss2 = history2.history["val_loss"][0]
         self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
 
+    def test_int64_inputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            prepared_for_class = self._prepare_for_class(
+                inputs_dict.copy(),
+                model_class,
+                return_labels=True if "labels" in inspect.signature(model_class.call).parameters.keys() else False,
+            )
+            if not any(
+                [tensor.dtype.is_integer for tensor in prepared_for_class.values() if isinstance(tensor, tf.Tensor)]
+            ):
+                return  # No integer inputs means no need for this test
+
+            prepared_for_class = {
+                key: tf.cast(tensor, tf.int64) if isinstance(tensor, tf.Tensor) and tensor.dtype.is_integer else tensor
+                for key, tensor in prepared_for_class.items()
+            }
+            model = model_class(config)
+            model(**prepared_for_class)  # No assertion, we're just checking this doesn't throw an error
+
     def test_generate_with_headmasking(self):
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
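Every hunk above applies the same fix: ops like tf.fill and tf.range produce int32 tensors by default, so combining their output with int64 model inputs (via tf.where, comparison operators, or arithmetic) fails with a dtype-mismatch error, which is exactly what the new test_int64_inputs exercises. Below is a minimal standalone sketch of the failure mode, not part of the patch; the tensor values and the pad token id of 0 are made up for illustration.

import tensorflow as tf

labels = tf.constant([[5, 6, 0]], dtype=tf.int64)  # int64 inputs, as test_int64_inputs produces
fill = tf.fill(tf.shape(labels), -100)  # a Python int fill value yields an int32 tensor

# tf.where(labels == 0, fill, labels)  # fails: int32 vs. int64 (InvalidArgumentError in eager mode)
masked = tf.where(labels == 0, tf.cast(fill, labels.dtype), labels)  # the patched form works

lengths = tf.constant([2], dtype=tf.int64)
# tf.range(3) < tf.expand_dims(lengths, axis=1)  # fails the same way: int32 vs. int64
mask = tf.range(3, dtype=lengths.dtype) < tf.expand_dims(lengths, axis=1)  # the patched form works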