Add test to ensure models can take int64 inputs (#17210)

* Add test to ensure models can take int64 inputs

* is_integer is an attribute, not a method

* Fix test when some inputs aren't tensors

* Add casts to blenderbot and blenderbot-small

* Add casts to the other failing models
Matt, 2022-05-12 16:09:25 +01:00 (committed by GitHub)
parent 5294fa12ee · commit f04257fdbc
9 changed files with 34 additions and 11 deletions
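
All of the model-side changes below follow one pattern: `tf.where` requires both branches to share a dtype, and `tf.fill(shape, -100)` infers int32 from the Python literal, so passing int64 labels used to crash. A minimal standalone sketch of the failure and the fix (not taken from the repository):

```python
import tensorflow as tf

labels = tf.constant([[1, 2, 0], [3, 0, 0]], dtype=tf.int64)
pad_token_id = 0

# Before the fix: tf.fill produces int32 here, and tf.where refuses to
# mix int32 and int64 branches, raising an error for int64 labels.
# tf.where(labels == pad_token_id, tf.fill(tf.shape(labels), -100), labels)

# After the fix: casting the fill to the labels' own dtype makes the
# branches agree whether labels arrive as int32 or int64.
masked = tf.where(
    labels == pad_token_id,
    tf.cast(tf.fill(tf.shape(labels), -100), labels.dtype),
    labels,
)
print(masked.numpy())  # [[1 2 -100] [3 -100 -100]]
```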

```diff
@@ -1287,7 +1287,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
```

```diff
@@ -1265,7 +1265,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
```

```diff
@@ -182,8 +182,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
         mask = padding_mask
     else:
         # assert lengths.max().item() <= slen
-        alen = tf.range(slen)
-        mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
+        alen = tf.range(slen, dtype=lengths.dtype)
+        mask = alen < tf.expand_dims(lengths, axis=1)
 
     # attention mask is the same as mask, or triangular inferior attention (causal)
     if causal:
```
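
The `get_masks` change guards the same failure from the other direction: `tf.range(slen)` defaults to int32, and TensorFlow comparison ops do not promote dtypes, so comparing it against int64 `lengths` raised an error. A standalone sketch under that assumption:

```python
import tensorflow as tf

slen = 5
lengths = tf.constant([3, 5], dtype=tf.int64)  # e.g. derived from int64 inputs

# Matching the range's dtype to lengths keeps the comparison valid for
# both int32 and int64 inputs; `<` is equivalent to tf.math.less here.
alen = tf.range(slen, dtype=lengths.dtype)
mask = alen < tf.expand_dims(lengths, axis=1)  # shape (2, 5), bool
print(mask.numpy())
# [[ True  True  True False False]
#  [ True  True  True  True  True]]
```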

```diff
@@ -1300,7 +1300,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
```

```diff
@@ -1317,7 +1317,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
```

```diff
@@ -1726,7 +1726,10 @@ class ProductIndexMap(IndexMap):
             raise ValueError("outer_index.batch_dims and inner_index.batch_dims " "must be the same.")
 
         super(ProductIndexMap, self).__init__(
-            indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
+            indices=(
+                inner_index.indices
+                + outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype)
+            ),
             num_segments=inner_index.num_segments * outer_index.num_segments,
             batch_dims=inner_index.batch_dims,
         )
@@ -1785,7 +1788,7 @@ def flatten(index, name="segmented_flatten"):
     for _ in range(index.batch_dims, index.indices.shape.rank):
         offset = tf.expand_dims(offset, -1)
 
-    indices = offset + index.indices
+    indices = tf.cast(offset, index.indices.dtype) + index.indices
     return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0)
```
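
In the TAPAS index utilities, the offsets and segment counts are built with int32 ops (`tf.range`, `tf.shape`), while the indices themselves may now be int64, hence the casts onto the indices' dtype. A toy sketch of the `flatten` pattern, with simplified names:

```python
import tensorflow as tf

indices = tf.constant([[0, 1], [1, 2]], dtype=tf.int64)  # per-batch segment ids
num_segments = 3
batch_size = tf.shape(indices)[0]  # int32 scalar

# The per-batch offset comes out int32, so it must be cast before it
# can be added to int64 indices; addition does not promote dtypes in TF.
offset = tf.range(batch_size) * num_segments   # [0, 3], int32
offset = tf.expand_dims(offset, -1)            # align rank with indices
flat = tf.cast(offset, indices.dtype) + indices
print(tf.reshape(flat, [-1]).numpy())          # [0 1 4 5]
```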

```diff
@@ -111,7 +111,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
     @staticmethod
     def _gather_logprob(logprob, target):
         lp_size = shape_list(logprob)
-        r = tf.range(lp_size[0])
+        r = tf.range(lp_size[0], dtype=target.dtype)
         idx = tf.stack([r, target], 1)
         return tf.gather_nd(logprob, idx)
```
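
Here the culprit is `tf.stack`, which also requires all of its inputs to share a dtype: stacking an int32 `tf.range` with int64 targets failed. A standalone sketch (using a static shape in place of `shape_list`):

```python
import tensorflow as tf

logprob = tf.math.log(tf.constant([[0.2, 0.8], [0.6, 0.4]]))
target = tf.constant([1, 0], dtype=tf.int64)

# The row range must match the targets' dtype before the two can be
# stacked into (row, column) gather indices.
r = tf.range(logprob.shape[0], dtype=target.dtype)
idx = tf.stack([r, target], 1)             # [[0, 1], [1, 0]]
print(tf.gather_nd(logprob, idx).numpy())  # log(0.8), log(0.6)
```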

```diff
@@ -92,8 +92,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
         mask = padding_mask
     else:
         # assert lengths.max().item() <= slen
-        alen = tf.range(slen)
-        mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
+        alen = tf.range(slen, dtype=lengths.dtype)
+        mask = alen < tf.expand_dims(lengths, axis=1)
 
     # attention mask is the same as mask, or triangular inferior attention (causal)
     if causal:
```

```diff
@@ -1372,6 +1372,26 @@ class TFModelTesterMixin:
         val_loss2 = history2.history["val_loss"][0]
         self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
 
+    def test_int64_inputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            prepared_for_class = self._prepare_for_class(
+                inputs_dict.copy(),
+                model_class,
+                return_labels=True if "labels" in inspect.signature(model_class.call).parameters.keys() else False,
+            )
+            if not any(
+                [tensor.dtype.is_integer for tensor in prepared_for_class.values() if isinstance(tensor, tf.Tensor)]
+            ):
+                return  # No integer inputs means no need for this test
+
+            prepared_for_class = {
+                key: tf.cast(tensor, tf.int64) if isinstance(tensor, tf.Tensor) and tensor.dtype.is_integer else tensor
+                for key, tensor in prepared_for_class.items()
+            }
+            model = model_class(config)
+            model(**prepared_for_class)  # No assertion, we're just checking this doesn't throw an error
+
     def test_generate_with_headmasking(self):
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
```
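
Note that the new test makes no assertion beyond the forward pass completing: the dtype bugs fixed above all surfaced as errors at call time, so casting every integer input to int64 and simply running the model is enough to catch regressions.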