Making TF OpenAI GPT model compliant with AMP and XLA (#10261)
* Fix AMP and XLA
* Remove useless var

This commit is contained in:
parent 3e116ed331
commit 34df26ec3a

2 changed files with 29 additions and 34 deletions
src/transformers/models/openai/modeling_tf_openai.py
@@ -81,7 +81,7 @@ class TFAttention(tf.keras.layers.Layer):
         pass
 
     @staticmethod
-    def causal_attention_mask(nd, ns, dtype):
+    def causal_attention_mask(nd, ns):
         """
         1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
         -1, ns-nd), but doesn't produce garbage on TPUs.
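The mask builder now returns the raw boolean comparison and leaves any dtype cast to the caller, which is what makes it reusable under mixed precision. As a quick illustration of what the i >= j - ns + nd comparison produces (a standalone snippet, not part of the patch), for nd = ns = 4:

    import tensorflow as tf

    nd, ns = 4, 4
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    m = i >= j - ns + nd  # boolean lower-triangular "keep" mask
    print(m.numpy().astype(int))
    # [[1 0 0 0]
    #  [1 1 0 0]
    #  [1 1 1 0]
    #  [1 1 1 1]]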
@@ -89,23 +89,24 @@ class TFAttention(tf.keras.layers.Layer):
         i = tf.range(nd)[:, None]
         j = tf.range(ns)
         m = i >= j - ns + nd
-        return tf.cast(m, dtype)
+        return m
 
     def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
         # q, k, v have shape [batch, heads, sequence, features]
         w = tf.matmul(q, k, transpose_b=True)
         if self.scale:
-            dk = tf.cast(shape_list(k)[-1], tf.float32)  # scale attention_scores
+            dk = tf.cast(shape_list(k)[-1], dtype=w.dtype)  # scale attention_scores
             w = w / tf.math.sqrt(dk)
 
         # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
         _, _, nd, ns = shape_list(w)
-        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
+        b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype)
         b = tf.reshape(b, [1, 1, nd, ns])
         w = w * b - 1e4 * (1 - b)
 
         if attention_mask is not None:
             # Apply the attention mask
+            attention_mask = tf.cast(attention_mask, dtype=w.dtype)
             w = w + attention_mask
 
         w = tf.nn.softmax(w, axis=-1)
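This hunk carries the core of the AMP fix: no hard-coded tf.float32 inside the forward pass. Under Keras mixed precision the attention scores w arrive as float16, so the scaling factor, the causal mask, and the padding mask are all cast to w.dtype instead. A minimal sketch of the pattern with an explicit float16 input (standalone, simplified from the code above):

    import tensorflow as tf

    def scaled_scores(q, k):
        # q, k: [batch, heads, seq, features]; float16 when AMP is active
        w = tf.matmul(q, k, transpose_b=True)
        # derive the cast from w.dtype rather than hard-coding tf.float32
        dk = tf.cast(tf.shape(k)[-1], dtype=w.dtype)
        return w / tf.math.sqrt(dk)

    q = tf.random.normal([1, 2, 4, 8], dtype=tf.float16)
    k = tf.random.normal([1, 2, 4, 8], dtype=tf.float16)
    print(scaled_scores(q, k).dtype)  # float16, no implicit upcast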
@@ -201,19 +202,25 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
         self.num_hidden_layers = config.n_layer
         self.vocab_size = config.vocab_size
+        self.n_embd = config.n_embd
+        self.n_positions = config.n_positions
+        self.initializer_range = config.initializer_range
 
         self.tokens_embed = TFSharedEmbeddings(
             config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed"
         )
-        self.positions_embed = tf.keras.layers.Embedding(
-            config.n_positions,
-            config.n_embd,
-            embeddings_initializer=get_initializer(config.initializer_range),
-            name="positions_embed",
-        )
         self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
         self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
 
+    def build(self, input_shape):
+        with tf.name_scope("positions_embed"):
+            self.positions_embed = self.add_weight(
+                name="embeddings",
+                shape=[self.n_positions, self.n_embd],
+                initializer=get_initializer(self.initializer_range),
+            )
+
+        super().build(input_shape)
+
     def get_input_embeddings(self):
         return self.tokens_embed
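Replacing the tf.keras.layers.Embedding sublayer with a weight created in build() keeps the positional lookup as a plain tf.gather over a variable, and the explicit name scope preserves the positions_embed/embeddings checkpoint key; the matching call-site change appears in the @@ -304 hunk below. A self-contained sketch of the pattern (the PositionalEmbedding class name and the stddev value are illustrative, not from the patch):

    import tensorflow as tf

    class PositionalEmbedding(tf.keras.layers.Layer):
        # Minimal stand-in: a raw weight matrix plus tf.gather instead of
        # a tf.keras.layers.Embedding sublayer.
        def __init__(self, n_positions, n_embd, **kwargs):
            super().__init__(**kwargs)
            self.n_positions = n_positions
            self.n_embd = n_embd

        def build(self, input_shape):
            with tf.name_scope("positions_embed"):
                self.positions_embed = self.add_weight(
                    name="embeddings",
                    shape=[self.n_positions, self.n_embd],
                    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
                )
            super().build(input_shape)

        def call(self, position_ids):
            # row lookup is just a gather over the weight matrix
            return tf.gather(self.positions_embed, position_ids)

    layer = PositionalEmbedding(n_positions=512, n_embd=768)
    print(layer(tf.constant([[0, 1, 2]])).shape)  # (1, 3, 768)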
@@ -268,7 +275,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs["position_ids"] is None:
-            inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1], dtype=tf.int32), axis=0)
+            inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1]), axis=0)
 
         if inputs["attention_mask"] is not None:
             # We create a 3D attention mask from a 2D tensor mask.
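Dropping dtype=tf.int32 is harmless because tf.range already defaults to int32 for integer inputs; not pinning the dtype simply leaves one fewer fixed-dtype constant in the graph for XLA to reconcile. A quick check of the default (input_shape stands in for the plain Python list that shape_list returns when shapes are static):

    import tensorflow as tf

    input_shape = [2, 7]  # shape_list returns plain Python ints when static
    position_ids = tf.expand_dims(tf.range(input_shape[-1]), axis=0)
    print(position_ids.dtype, position_ids.shape)  # <dtype: 'int32'> (1, 7)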
@@ -284,8 +291,11 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
 
-            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
-            inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
+            one_cst = tf.constant(1.0)
+            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
+            inputs["attention_mask"] = tf.multiply(
+                tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
+            )
         else:
             inputs["attention_mask"] = None
 
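The rewrite keeps the same arithmetic as before, (1.0 - mask) * -10000.0: keep-positions become 0.0 and padding positions become -10000.0, which is effectively minus infinity once added to the pre-softmax scores. Deriving the cast from a tf.constant rather than a hard-coded tf.float32 keeps the dtype decision in one place. Worked example with a toy padding mask:

    import tensorflow as tf

    attention_mask = tf.constant([[1, 1, 0]])  # 1 = attend, 0 = padding
    one_cst = tf.constant(1.0)
    mask = tf.cast(attention_mask, dtype=one_cst.dtype)
    additive = tf.multiply(tf.subtract(one_cst, mask), tf.constant(-10000.0))
    print(additive.numpy())  # [[-0. -0. -10000.]]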
@@ -304,7 +314,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
 
         if inputs["inputs_embeds"] is None:
             inputs["inputs_embeds"] = self.tokens_embed(inputs["input_ids"], mode="embedding")
-        position_embeds = self.positions_embed(inputs["position_ids"])
+        position_embeds = tf.gather(self.positions_embed, inputs["position_ids"])
         if inputs["token_type_ids"] is not None:
             inputs["token_type_ids"] = tf.reshape(
                 inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
@@ -903,7 +913,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss):
 
         hidden_states = transformer_outputs[0]
         logits = self.score(hidden_states)
-        logits_shape = shape_list(logits)
         in_logits = None
         if self.config.pad_token_id is None:
             sequence_lengths = -1
@@ -911,22 +920,16 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss):
             if inputs["input_ids"] is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
-                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                        tf.cast(
+                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
+                            dtype=inputs["input_ids"].dtype,
+                        ),
                         -1,
                         keepdims=False,
                     )
                     - 1
                 )
-
-                def get_seq_element(sequence_position, input_batch):
-                    return tf.strided_slice(
-                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
-                    )
-
-                result = tf.map_fn(
-                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
-                )
-                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
             else:
                 sequence_lengths = -1
                 logger.warning(
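Here the per-example tf.map_fn over tf.strided_slice, which pulled out the logits at each sequence's last non-padding token, collapses into a single tf.gather with batch_dims=1; that also removes the only use of the logits_shape variable deleted in the hunk above. Toy-shaped example of the selection (values are illustrative):

    import tensorflow as tf

    logits = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4])  # [batch, seq, labels]
    input_ids = tf.constant([[5, 6, 0], [7, 0, 0]])  # 0 = pad_token_id
    # index of the last non-padding token per example: [1, 0]
    sequence_lengths = tf.reduce_sum(
        tf.cast(tf.math.not_equal(input_ids, 0), dtype=input_ids.dtype), -1
    ) - 1
    in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
    print(in_logits.shape)  # (2, 4): one logits row per example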
tests/test_modeling_tf_openai.py
@@ -246,14 +246,6 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs)
 
-    def test_mixed_precision(self):
-        # TODO JP: Make OpenAIGPT float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make OpenAIGPT XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
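With the model fixed, the two overriding stubs are removed so that the common test_mixed_precision and test_xla_mode checks inherited from TFModelTesterMixin run again for OpenAI GPT. In spirit, the mixed-precision check amounts to running a forward pass under the mixed_float16 policy (a simplified sketch, not the actual mixin test; the Dense layer is just a placeholder model):

    import tensorflow as tf

    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    try:
        layer = tf.keras.layers.Dense(4)  # placeholder for the real model
        out = layer(tf.random.normal([2, 8]))
        assert out.dtype == tf.float16  # computed in half precision
    finally:
        tf.keras.mixed_precision.set_global_policy("float32")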