From 34df26ec3a2212419b596fa80344fdae79fed4f6 Mon Sep 17 00:00:00 2001
From: Julien Plu
Date: Fri, 19 Feb 2021 15:33:25 +0100
Subject: [PATCH] Making TF OpenAI GPT model compliant with AMP and XLA
 (#10261)

* Fix AMP and XLA

* Remove useless var
---
 .../models/openai/modeling_tf_openai.py | 55 ++++++++++---------
 tests/test_modeling_tf_openai.py        |  8 ---
 2 files changed, 29 insertions(+), 34 deletions(-)

diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py
index 5c586fa09..0ce16b670 100644
--- a/src/transformers/models/openai/modeling_tf_openai.py
+++ b/src/transformers/models/openai/modeling_tf_openai.py
@@ -81,7 +81,7 @@ class TFAttention(tf.keras.layers.Layer):
         pass

     @staticmethod
-    def causal_attention_mask(nd, ns, dtype):
+    def causal_attention_mask(nd, ns):
         """
         1's in the lower triangle, counting from the lower right corner. Same as
         tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
@@ -89,23 +89,24 @@ class TFAttention(tf.keras.layers.Layer):
         i = tf.range(nd)[:, None]
         j = tf.range(ns)
         m = i >= j - ns + nd
-        return tf.cast(m, dtype)
+        return m

     def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
         # q, k, v have shape [batch, heads, sequence, features]
         w = tf.matmul(q, k, transpose_b=True)
         if self.scale:
-            dk = tf.cast(shape_list(k)[-1], tf.float32)  # scale attention_scores
+            dk = tf.cast(shape_list(k)[-1], dtype=w.dtype)  # scale attention_scores
             w = w / tf.math.sqrt(dk)

         # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
         _, _, nd, ns = shape_list(w)
-        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
+        b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype)
         b = tf.reshape(b, [1, 1, nd, ns])
         w = w * b - 1e4 * (1 - b)

         if attention_mask is not None:
             # Apply the attention mask
+            attention_mask = tf.cast(attention_mask, dtype=w.dtype)
             w = w + attention_mask

         w = tf.nn.softmax(w, axis=-1)
@@ -201,19 +202,25 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
         self.num_hidden_layers = config.n_layer
         self.vocab_size = config.vocab_size
         self.n_embd = config.n_embd
+        self.n_positions = config.n_positions
+        self.initializer_range = config.initializer_range

         self.tokens_embed = TFSharedEmbeddings(
             config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed"
         )
-        self.positions_embed = tf.keras.layers.Embedding(
-            config.n_positions,
-            config.n_embd,
-            embeddings_initializer=get_initializer(config.initializer_range),
-            name="positions_embed",
-        )
         self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
         self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]

+    def build(self, input_shape):
+        with tf.name_scope("positions_embed"):
+            self.positions_embed = self.add_weight(
+                name="embeddings",
+                shape=[self.n_positions, self.n_embd],
+                initializer=get_initializer(self.initializer_range),
+            )
+
+        super().build(input_shape)
+
     def get_input_embeddings(self):
         return self.tokens_embed

@@ -268,7 +275,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs["position_ids"] is None:
-            inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1], dtype=tf.int32), axis=0)
+            inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1]), axis=0)

         if inputs["attention_mask"] is not None:
             # We create a 3D attention mask from a 2D tensor mask.
@@ -284,8 +291,11 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
             # positions we want to attend and -10000.0 for masked positions.
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
-            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
-            inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
+            one_cst = tf.constant(1.0)
+            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
+            inputs["attention_mask"] = tf.multiply(
+                tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
+            )
         else:
             inputs["attention_mask"] = None

@@ -304,7 +314,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):

         if inputs["inputs_embeds"] is None:
             inputs["inputs_embeds"] = self.tokens_embed(inputs["input_ids"], mode="embedding")
-        position_embeds = self.positions_embed(inputs["position_ids"])
+        position_embeds = tf.gather(self.positions_embed, inputs["position_ids"])
         if inputs["token_type_ids"] is not None:
             inputs["token_type_ids"] = tf.reshape(
                 inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
@@ -903,7 +913,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc

         hidden_states = transformer_outputs[0]
         logits = self.score(hidden_states)
-        logits_shape = shape_list(logits)
         in_logits = None
         if self.config.pad_token_id is None:
             sequence_lengths = -1
@@ -911,22 +920,16 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
             if inputs["input_ids"] is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
-                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                        tf.cast(
+                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
+                            dtype=inputs["input_ids"].dtype,
+                        ),
                         -1,
                         keepdims=False,
                     )
                     - 1
                 )
-
-                def get_seq_element(sequence_position, input_batch):
-                    return tf.strided_slice(
-                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
-                    )
-
-                result = tf.map_fn(
-                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
-                )
-                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
             else:
                 sequence_lengths = -1
                 logger.warning(
diff --git a/tests/test_modeling_tf_openai.py b/tests/test_modeling_tf_openai.py
index 87e105843..4dc684adb 100644
--- a/tests/test_modeling_tf_openai.py
+++ b/tests/test_modeling_tf_openai.py
@@ -246,14 +246,6 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs)

-    def test_mixed_precision(self):
-        # TODO JP: Make OpenAIGPT float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make OpenAIGPT XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
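
Note (not part of the patch): a minimal smoke test for the two paths this change unblocks, sketched under assumptions rather than taken from the PR. It assumes TensorFlow >= 2.5 (for the jit_compile flag; on 2.4 the equivalent is experimental_compile) and a transformers install that already contains this commit; "openai-gpt" is the standard Hub checkpoint name.

    import tensorflow as tf
    from transformers import OpenAIGPTTokenizer, TFOpenAIGPTModel

    # AMP: run the model under a float16 compute policy. The old code hard-coded
    # tf.float32 for masks and scaling factors, which clashed with float16
    # activations; the patch derives every dtype from the tensors themselves
    # (e.g. dtype=w.dtype), so the forward pass now works under this policy.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    model = TFOpenAIGPTModel.from_pretrained("openai-gpt")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")

    outputs = model(inputs)
    print(outputs.last_hidden_state.dtype)  # float16 under the mixed policy

    # XLA: compile the forward pass. The old Keras Embedding layer for positions
    # and the tf.map_fn/tf.strided_slice gather did not compile cleanly; the
    # patch replaces them with an add_weight matrix plus tf.gather, which do.
    xla_forward = tf.function(model, jit_compile=True)
    outputs = xla_forward(inputs["input_ids"])

The dtype-driven casts are what let the same graph run under both float32 and mixed_float16, and tf.gather (with batch_dims=1 in the sequence-classification head) is the XLA-friendly replacement for the strided-slice map, which is why the two skipped tests could be removed from test_modeling_tf_openai.py.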