Make TF CTRL compliant with XLA and AMP (#10209)
* Fix XLA and AMP
* Apply style
* Remove useless cast
This commit is contained in:
parent fdb2351ebb
commit 7246785a67

2 changed files with 18 additions and 33 deletions
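The common theme of the diff: every hard-coded float32 cast and Python-level constant is replaced by a dtype derived from a nearby tensor, so the graph type-checks under AMP (mixed float16) and lowers under XLA. As a hedged sketch of how such a model might be exercised after this change (the "ctrl" checkpoint name and the prompt are illustrative, not part of this commit):

    import tensorflow as tf
    from transformers import CTRLTokenizer, TFCTRLModel

    # AMP: compute in float16, keep variables in float32.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

    tokenizer = CTRLTokenizer.from_pretrained("ctrl")
    model = TFCTRLModel.from_pretrained("ctrl")
    inputs = tokenizer("Links Hello world", return_tensors="tf")

    # XLA: compile the forward pass; any mixed-dtype op would fail here.
    @tf.function(experimental_compile=True)  # jit_compile=True on newer TF
    def forward(input_ids):
        return model(input_ids)

    outputs = forward(inputs["input_ids"])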
src/transformers/models/ctrl/modeling_tf_ctrl.py

@@ -48,7 +48,7 @@ TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
 
 
 def angle_defn(pos, i, d_model_size):
-    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model_size))
+    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / d_model_size)
     return pos * angle_rates
 
 
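A note on the "Remove useless cast" item: NumPy already promotes this integer/float mix to float64 with or without np.float32(...), and the table is reduced to float32 once in positional_encoding below, so the cast bought nothing. A quick check outside the library (sizes are illustrative):

    import numpy as np

    pos = np.arange(10)[:, np.newaxis]
    i = np.arange(64)[np.newaxis, :]
    d_model_size = 64

    old = pos * (1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model_size)))
    new = pos * (1 / np.power(10000, (2 * (i // 2)) / d_model_size))
    assert np.allclose(old, new)  # identical values; only an intermediate dtype differed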
@@ -58,9 +58,8 @@ def positional_encoding(position, d_model_size):
 
     sines = np.sin(angle_rads[:, 0::2])
     cosines = np.cos(angle_rads[:, 1::2])
-    pos_encoding = tf.convert_to_tensor(np.concatenate([sines, cosines], axis=-1))
-    # pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
+    pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1), dtype=tf.float32)
 
     return pos_encoding
 
 
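np.sin/np.cos emit float64, so the old tf.convert_to_tensor produced a float64 table that had to be downcast at every use (the deleted commented-out line hints at that history). Building the table as float32 once gives it a single dtype at rest; under AMP it is cast to the running dtype where it is consumed (see the pos_embeds hunk further down). A small demonstration of the dtype trap:

    import numpy as np
    import tensorflow as tf

    table = np.concatenate([np.sin(np.zeros((4, 2))), np.cos(np.zeros((4, 2)))], axis=-1)
    assert tf.convert_to_tensor(table).dtype == tf.float64  # inherited from NumPy
    assert tf.cast(table, dtype=tf.float32).dtype == tf.float32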
@@ -68,14 +67,15 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
     # calculate attention
     matmul_qk = tf.matmul(q, k, transpose_b=True)
 
-    dk = tf.cast(shape_list(k)[-1], tf.float32)
+    dk = tf.cast(shape_list(k)[-1], dtype=matmul_qk.dtype)
     scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
 
     if mask is not None:
-        scaled_attention_logits += mask * -1e4
+        scaled_attention_logits += tf.cast(mask * -1e4, dtype=scaled_attention_logits.dtype)
 
     if attention_mask is not None:
         # Apply the attention mask
+        attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
         scaled_attention_logits = scaled_attention_logits + attention_mask
 
     attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
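scaled_dot_product_attention now derives every dtype from its operands instead of pinning float32, which is what AMP requires: under mixed float16 the q/k/v projections arrive as float16 while the causal mask stays float32. A standalone check of the pattern (shapes and values are illustrative):

    import tensorflow as tf

    q = tf.random.normal((1, 2, 4, 8), dtype=tf.float16)
    k = tf.random.normal((1, 2, 4, 8), dtype=tf.float16)
    mask = tf.ones((4, 4))  # float32, like the band-part mask built in the main layer

    matmul_qk = tf.matmul(q, k, transpose_b=True)
    dk = tf.cast(tf.shape(k)[-1], dtype=matmul_qk.dtype)
    logits = matmul_qk / tf.math.sqrt(dk)
    # Without the cast, adding a float32 mask to float16 logits raises InvalidArgumentError.
    logits += tf.cast(mask * -1e4, dtype=logits.dtype)
    assert logits.dtype == tf.float16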
@@ -332,10 +332,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
 
-            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
-            inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
-        else:
-            inputs["attention_mask"] = None
+            one_cst = tf.constant(1.0)
+            ten_thousand_cst = tf.constant(-10000.0)
+            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
+            inputs["attention_mask"] = tf.multiply(tf.subtract(one_cst, inputs["attention_mask"]), ten_thousand_cst)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
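Two things happen in this hunk: the mask arithmetic is expressed through tf.constant/tf.subtract/tf.multiply with an explicit dtype anchor (one_cst), and the redundant else branch goes away, since on that path inputs["attention_mask"] is already None. The transform itself is unchanged, as a quick trace shows:

    import tensorflow as tf

    attention_mask = tf.constant([[1, 1, 0]])  # 1 = attend, 0 = masked
    one_cst = tf.constant(1.0)
    ten_thousand_cst = tf.constant(-10000.0)
    mask = tf.cast(attention_mask, dtype=one_cst.dtype)
    mask = tf.multiply(tf.subtract(one_cst, mask), ten_thousand_cst)
    # Attended positions get a 0.0 bias, masked positions get -10000.0.
    print(mask.numpy())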
@@ -351,9 +351,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
                 inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
             )
             token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
-            token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
+            token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
         else:
-            token_type_embeds = 0
+            token_type_embeds = tf.constant(0.0)
         inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
 
         if inputs["inputs_embeds"] is None:
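token_type_embeds = 0 left a Python int on one branch and a tensor on the other; tf.constant(0.0) makes both branches tensor-valued, which the pos_embeds cast in the next hunk (dtype=token_type_embeds.dtype) depends on, and keeps the branch types uniform for XLA. A minimal illustration:

    import tensorflow as tf

    token_type_embeds = tf.constant(0.0)  # a scalar tensor, so .dtype is defined
    pos_embeds = tf.cast(tf.zeros((2, 3), dtype=tf.float64), dtype=token_type_embeds.dtype)
    hidden = pos_embeds + token_type_embeds  # broadcasts exactly like the int 0 did
    assert hidden.dtype == tf.float32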
@@ -361,10 +361,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             seq_len = input_shape[-1]
             mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
 
-        inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
+        inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, inputs["inputs_embeds"].dtype))
 
         pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
-
+        pos_embeds = tf.cast(pos_embeds, dtype=token_type_embeds.dtype)
         hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
 
         hidden_states = self.dropout(hidden_states, training=inputs["training"])
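Under mixed float16 the embedding output is float16 while the precomputed pos_encoding table stays float32, so pos_embeds must be cast before the sum. A sketch of the failure mode; it anchors on inputs_embeds.dtype where the commit uses token_type_embeds.dtype, and the two agree at this point:

    import tensorflow as tf

    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    embedding = tf.keras.layers.Embedding(10, 8)
    inputs_embeds = embedding(tf.constant([[1, 2, 3]]))  # float16 under AMP
    pos_encoding = tf.zeros((16, 8), dtype=tf.float32)   # float32 table, as above

    pos_embeds = tf.gather(pos_encoding, tf.constant([[0, 1, 2]]))
    pos_embeds = tf.cast(pos_embeds, dtype=inputs_embeds.dtype)  # align dtypes first
    hidden_states = inputs_embeds + pos_embeds  # would raise without the cast
    tf.keras.mixed_precision.set_global_policy("float32")  # restore the default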
@@ -857,7 +857,6 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassificationLoss):
 
         hidden_states = transformer_outputs[0]
         logits = self.classifier(hidden_states)
-        logits_shape = shape_list(logits)
         in_logits = None
         if self.config.pad_token_id is None:
             sequence_lengths = -1
@@ -865,22 +864,16 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassificationLoss):
             if inputs["input_ids"] is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
-                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                        tf.cast(
+                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
+                            dtype=inputs["input_ids"].dtype,
+                        ),
                         -1,
                         keepdims=False,
                     )
                     - 1
                 )
-
-                def get_seq_element(sequence_position, input_batch):
-                    return tf.strided_slice(
-                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
-                    )
-
-                result = tf.map_fn(
-                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
-                )
-                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
             else:
                 sequence_lengths = -1
                 logger.warning(
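The tf.map_fn/tf.strided_slice construction (and the logits_shape bookkeeping dropped in the previous hunk) sliced out the last non-padded position one example at a time, which does not lower cleanly under XLA; tf.gather with batch_dims=1 performs the same per-row lookup in a single op, and the tf.cast now follows the dtype of input_ids instead of forcing int32. An equivalence sketch with illustrative shapes:

    import tensorflow as tf

    batch, seq, num_labels = 2, 4, 3
    logits = tf.reshape(tf.range(batch * seq * num_labels, dtype=tf.float32), (batch, seq, num_labels))
    sequence_lengths = tf.constant([1, 3])  # index of the last real token per row

    # Picks logits[i, sequence_lengths[i], :] for every row i in one shot.
    in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
    assert in_logits.shape == (batch, num_labels)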
tests/test_modeling_tf_ctrl.py

@@ -222,14 +222,6 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
         name = model.get_bias()
         assert name is None
 
-    def test_mixed_precision(self):
-        # TODO JP: Make CTRL float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make CTRL XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
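Deleting the two overriding stubs re-enables the shared test_mixed_precision and test_xla_mode from TFModelTesterMixin for CTRL, which is the real verification of this commit. A standalone approximation of what the mixed-precision side exercises (config values are illustrative, not taken from the test suite):

    import tensorflow as tf
    from transformers import CTRLConfig, TFCTRLModel

    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    try:
        config = CTRLConfig(vocab_size=99, n_positions=64, n_embd=32, dff=64, n_layer=2, n_head=4)
        model = TFCTRLModel(config)
        outputs = model(tf.constant([[1, 2, 3]]))  # must not raise a dtype error
    finally:
        tf.keras.mixed_precision.set_global_policy("float32")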