diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
index 6936030af..996eb7f42 100644
--- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
+++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
@@ -59,6 +59,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
         self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))
 
     def call(self, pos_seq, bsz=None):
+        self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype)
         sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
         pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
 
@@ -186,6 +187,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
 
         if mems is not None:
+            mems = tf.cast(mems, dtype=w.dtype)
             cat = tf.concat([mems, w], 0)
             if self.pre_lnorm:
                 w_heads = self.qkv_net(self.layer_norm(cat))
@@ -227,7 +229,8 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         # compute attention probability
         if attn_mask is not None:
             attn_mask_t = attn_mask[:, :, None, None]
-            attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
+            attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype)
+            attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t
 
         # [qlen x klen x bsz x n_head]
         attn_prob = tf.nn.softmax(attn_score, axis=1)
@@ -313,6 +316,27 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
         return outputs
 
 
+class TFTransfoEmbeddings(tf.keras.layers.Layer):
+    def __init__(self, vocab_size, emb_size, init_std, **kwargs):
+        super().__init__(**kwargs)
+
+        self.vocab_size = vocab_size
+        self.emb_size = emb_size
+        self.init_std = init_std
+
+    def build(self, input_shape):
+        self.weight = self.add_weight(
+            shape=(self.vocab_size, self.emb_size),
+            initializer=get_initializer(self.init_std),
+            name="embeddings",
+        )
+
+        super().build(input_shape)
+
+    def call(self, inputs):
+        return tf.gather(self.weight, inputs)
+
+
 class TFAdaptiveEmbedding(tf.keras.layers.Layer):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs):
         super().__init__(**kwargs)
@@ -331,6 +355,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
 
         self.emb_layers = []
         self.emb_projs = []
+
         if div_val == 1:
             raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
         else:
@@ -338,10 +363,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                 d_emb_i = d_embed // (div_val ** i)
                 self.emb_layers.append(
-                    tf.keras.layers.Embedding(
+                    TFTransfoEmbeddings(
                         r_idx - l_idx,
                         d_emb_i,
-                        embeddings_initializer=get_initializer(init_std),
+                        init_std,
                         name="emb_layers_._{}".format(i),
                     )
                 )
@@ -357,6 +382,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                     name="emb_projs_._{}".format(i),
                 )
             )
+
         super().build(input_shape)
 
     def call(self, inp):
@@ -374,8 +400,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 emb_i = self.emb_layers[i](inp_i)
                 emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])
 
-                mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
-                emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(shape_list(emb_flat), dtype=tf.int64))
+                mask_idx = tf.where(mask_i)
+                scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat))
+                emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)
+                emb_flat += scatter
 
         embed_shape = shape_list(inp) + [self.d_proj]
         embed = tf.reshape(emb_flat, embed_shape)
@@ -501,7 +529,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
             end_idx = mlen + tf.math.maximum(0, qlen)
             beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
             for i in range(len(hids)):
-
+                mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
                 cat = tf.concat([mems[i], hids[i]], axis=0)
                 tf.stop_gradient(cat)
                 new_mems.append(cat[beg_idx:end_idx])
@@ -1113,7 +1141,6 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
 
         hidden_states = transformer_outputs[0]
         logits = self.score(hidden_states)
-        logits_shape = shape_list(logits)
         in_logits = None
         if self.config.pad_token_id is None:
             sequence_lengths = -1
@@ -1121,22 +1148,16 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
             if inputs["input_ids"] is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
-                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                        tf.cast(
+                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
+                            dtype=inputs["input_ids"].dtype,
+                        ),
                         -1,
                         keepdims=False,
                     )
                     - 1
                 )
-
-                def get_seq_element(sequence_position, input_batch):
-                    return tf.strided_slice(
-                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
-                    )
-
-                result = tf.map_fn(
-                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
-                )
-                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
             else:
                 sequence_lengths = -1
                 logger.warning(
diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
index 84994f9b4..9797a8fa6 100644
--- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
+++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
@@ -131,7 +131,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
         else:
             hidden_sizes = shape_list(hidden)
             out = []
-            loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
+            loss = tf.zeros(hidden_sizes[:2])
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                 if target is not None:
@@ -168,7 +168,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
                     cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
                     cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
                 if target is not None:
-                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
+                    loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss))
             out = tf.concat(out, axis=-1)
 
         if target is not None:
diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py
index 7232091a0..a7b6fc3d9 100644
--- a/tests/test_modeling_tf_transfo_xl.py
+++ b/tests/test_modeling_tf_transfo_xl.py
@@ -205,10 +205,6 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
         name = model.get_bias()
         assert name is None
 
-    def test_mixed_precision(self):
-        # TODO JP: Make TransfoXL float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make TransfoXL XLA compliant
         pass
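A minimal standalone sketch (not part of the patch) of the batched-gather pattern the last modeling hunk switches to, replacing the per-example tf.map_fn/tf.strided_slice loop; the shapes and index values below are hypothetical, for illustration only:

    import tensorflow as tf

    # logits: [batch, seq_len, num_labels]; sequence_lengths holds the index
    # of the last non-padding token in each row (hypothetical values).
    logits = tf.random.normal((2, 5, 3))
    sequence_lengths = tf.constant([4, 2])

    # batch_dims=1 pairs row i of `logits` with sequence_lengths[i], so a
    # single vectorized gather selects the classification logits at each
    # row's last real token -- no map_fn, and no hard-coded dtype to trip
    # up mixed-precision execution.
    in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
    print(in_logits.shape)  # (2, 3)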