diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py
index b466f718d..73c76fb1a 100644
--- a/src/transformers/models/ctrl/modeling_tf_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -48,7 +48,7 @@ TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
 
 
 def angle_defn(pos, i, d_model_size):
-    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model_size))
+    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / d_model_size)
     return pos * angle_rates
 
 
@@ -58,9 +58,8 @@ def positional_encoding(position, d_model_size):
 
     sines = np.sin(angle_rads[:, 0::2])
     cosines = np.cos(angle_rads[:, 1::2])
+    pos_encoding = tf.convert_to_tensor(np.concatenate([sines, cosines], axis=-1))
 
-    # pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
-    pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1), dtype=tf.float32)
     return pos_encoding
 
 
@@ -68,14 +67,15 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
     # calculate attention
     matmul_qk = tf.matmul(q, k, transpose_b=True)
 
-    dk = tf.cast(shape_list(k)[-1], tf.float32)
+    dk = tf.cast(shape_list(k)[-1], dtype=matmul_qk.dtype)
     scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
 
     if mask is not None:
-        scaled_attention_logits += mask * -1e4
+        scaled_attention_logits += tf.cast(mask * -1e4, dtype=scaled_attention_logits.dtype)
 
     if attention_mask is not None:
         # Apply the attention mask
+        attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
         scaled_attention_logits = scaled_attention_logits + attention_mask
 
     attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
@@ -332,10 +332,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
 
-            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
-            inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
-        else:
-            inputs["attention_mask"] = None
+            one_cst = tf.constant(1.0)
+            ten_thousand_cst = tf.constant(-10000.0)
+            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
+            inputs["attention_mask"] = tf.multiply(tf.subtract(one_cst, inputs["attention_mask"]), ten_thousand_cst)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -351,9 +351,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
                 inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
             )
             token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
-            token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
+            token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
         else:
-            token_type_embeds = 0
+            token_type_embeds = tf.constant(0.0)
         inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
 
         if inputs["inputs_embeds"] is None:
@@ -361,10 +361,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         seq_len = input_shape[-1]
         mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
 
-        inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
+        inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, inputs["inputs_embeds"].dtype))
 
         pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
-
+        pos_embeds = tf.cast(pos_embeds, dtype=token_type_embeds.dtype)
         hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
 
         hidden_states = self.dropout(hidden_states, training=inputs["training"])
@@ -857,7 +857,6 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
 
         hidden_states = transformer_outputs[0]
         logits = self.classifier(hidden_states)
-        logits_shape = shape_list(logits)
         in_logits = None
         if self.config.pad_token_id is None:
             sequence_lengths = -1
@@ -865,22 +864,16 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
             if inputs["input_ids"] is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
-                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                        tf.cast(
+                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
+                            dtype=inputs["input_ids"].dtype,
+                        ),
                         -1,
                         keepdims=False,
                     )
                     - 1
                 )
-
-                def get_seq_element(sequence_position, input_batch):
-                    return tf.strided_slice(
-                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
-                    )
-
-                result = tf.map_fn(
-                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
-                )
-                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
             else:
                 sequence_lengths = -1
                 logger.warning(
diff --git a/tests/test_modeling_tf_ctrl.py b/tests/test_modeling_tf_ctrl.py
index 781a7df85..e9531552b 100644
--- a/tests/test_modeling_tf_ctrl.py
+++ b/tests/test_modeling_tf_ctrl.py
@@ -222,14 +222,6 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
             name = model.get_bias()
             assert name is None
 
-    def test_mixed_precision(self):
-        # TODO JP: Make CTRL float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make CTRL XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: