mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Making TF TransfoXL model compliant with AMP (#10264)
* Fix AMP * Apply style * Remove unused import
This commit is contained in:
parent
86caeb7636
commit
3e116ed331
3 changed files with 41 additions and 24 deletions
|
|
@ -59,6 +59,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
|
|||
self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))
|
||||
|
||||
def call(self, pos_seq, bsz=None):
|
||||
self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype)
|
||||
sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
|
||||
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
|
||||
|
||||
|
|
@ -186,6 +187,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
|
||||
|
||||
if mems is not None:
|
||||
mems = tf.cast(mems, dtype=w.dtype)
|
||||
cat = tf.concat([mems, w], 0)
|
||||
if self.pre_lnorm:
|
||||
w_heads = self.qkv_net(self.layer_norm(cat))
|
||||
|
|
@ -227,7 +229,8 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||
# compute attention probability
|
||||
if attn_mask is not None:
|
||||
attn_mask_t = attn_mask[:, :, None, None]
|
||||
attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
|
||||
attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype)
|
||||
attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t
|
||||
|
||||
# [qlen x klen x bsz x n_head]
|
||||
attn_prob = tf.nn.softmax(attn_score, axis=1)
|
||||
|
|
@ -313,6 +316,27 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||
return outputs
|
||||
|
||||
|
||||
class TFTransfoEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, vocab_size, emb_size, init_std, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.emb_size = emb_size
|
||||
self.init_std = init_std
|
||||
|
||||
def build(self, input_shape):
|
||||
self.weight = self.add_weight(
|
||||
shape=(self.vocab_size, self.emb_size),
|
||||
initializer=get_initializer(self.init_std),
|
||||
name="embeddings",
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, inputs):
|
||||
return tf.gather(self.weight, inputs)
|
||||
|
||||
|
||||
class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
||||
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -331,6 +355,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||
|
||||
self.emb_layers = []
|
||||
self.emb_projs = []
|
||||
|
||||
if div_val == 1:
|
||||
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
|
||||
else:
|
||||
|
|
@ -338,10 +363,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
|
||||
d_emb_i = d_embed // (div_val ** i)
|
||||
self.emb_layers.append(
|
||||
tf.keras.layers.Embedding(
|
||||
TFTransfoEmbeddings(
|
||||
r_idx - l_idx,
|
||||
d_emb_i,
|
||||
embeddings_initializer=get_initializer(init_std),
|
||||
init_std,
|
||||
name="emb_layers_._{}".format(i),
|
||||
)
|
||||
)
|
||||
|
|
@ -357,6 +382,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||
name="emb_projs_._{}".format(i),
|
||||
)
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, inp):
|
||||
|
|
@ -374,8 +400,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||
emb_i = self.emb_layers[i](inp_i)
|
||||
emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])
|
||||
|
||||
mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
|
||||
emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(shape_list(emb_flat), dtype=tf.int64))
|
||||
mask_idx = tf.where(mask_i)
|
||||
scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat))
|
||||
emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)
|
||||
emb_flat += scatter
|
||||
|
||||
embed_shape = shape_list(inp) + [self.d_proj]
|
||||
embed = tf.reshape(emb_flat, embed_shape)
|
||||
|
|
@ -501,7 +529,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||
end_idx = mlen + tf.math.maximum(0, qlen)
|
||||
beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
|
||||
for i in range(len(hids)):
|
||||
|
||||
mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
|
||||
cat = tf.concat([mems[i], hids[i]], axis=0)
|
||||
tf.stop_gradient(cat)
|
||||
new_mems.append(cat[beg_idx:end_idx])
|
||||
|
|
@ -1113,7 +1141,6 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
|
|||
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
logits_shape = shape_list(logits)
|
||||
in_logits = None
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
|
|
@ -1121,22 +1148,16 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
|
|||
if inputs["input_ids"] is not None:
|
||||
sequence_lengths = (
|
||||
tf.reduce_sum(
|
||||
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
|
||||
tf.cast(
|
||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
|
||||
dtype=inputs["input_ids"].dtype,
|
||||
),
|
||||
-1,
|
||||
keepdims=False,
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
def get_seq_element(sequence_position, input_batch):
|
||||
return tf.strided_slice(
|
||||
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
|
||||
)
|
||||
|
||||
result = tf.map_fn(
|
||||
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
|
||||
)
|
||||
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
|
|||
else:
|
||||
hidden_sizes = shape_list(hidden)
|
||||
out = []
|
||||
loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
|
||||
loss = tf.zeros(hidden_sizes[:2])
|
||||
for i in range(len(self.cutoffs)):
|
||||
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
|
||||
if target is not None:
|
||||
|
|
@ -168,7 +168,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
|
|||
cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
|
||||
cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
|
||||
if target is not None:
|
||||
loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
|
||||
loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss))
|
||||
out = tf.concat(out, axis=-1)
|
||||
|
||||
if target is not None:
|
||||
|
|
|
|||
|
|
@ -205,10 +205,6 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
def test_mixed_precision(self):
|
||||
# TODO JP: Make TransfoXL float16 compliant
|
||||
pass
|
||||
|
||||
def test_xla_mode(self):
|
||||
# TODO JP: Make TransfoXL XLA compliant
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue