diff --git a/src/transformers/activations_tf.py b/src/transformers/activations_tf.py new file mode 100644 index 000000000..89f445d67 --- /dev/null +++ b/src/transformers/activations_tf.py @@ -0,0 +1,65 @@ +import math + +import tensorflow as tf + + +def gelu(x): + """Gaussian Error Linear Unit. + Original Implementation of the gelu activation function in Google Bert repo when initially created. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + x = tf.convert_to_tensor(x) + cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0))) + + return x * cdf + + +def gelu_new(x): + """Gaussian Error Linear Unit. + This is a smoother version of the GELU. + Original paper: https://arxiv.org/abs/1606.08415 + Args: + x: float Tensor to perform activation. + Returns: + `x` with the GELU activation applied. + """ + x = tf.convert_to_tensor(x) + pi = tf.cast(math.pi, x.dtype) + coeff = tf.cast(0.044715, x.dtype) + cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3)))) + + return x * cdf + + +def mish(x): + x = tf.convert_to_tensor(x) + + return x * tf.tanh(tf.math.softplus(x)) + + +def gelu_fast(x): + x = tf.convert_to_tensor(x) + coeff1 = tf.cast(0.044715, x.dtype) + coeff2 = tf.cast(0.7978845608, x.dtype) + + return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x))) + + +ACT2FN = { + "gelu": tf.keras.layers.Activation(gelu), + "relu": tf.keras.activations.relu, + "swish": tf.keras.activations.swish, + "gelu_new": tf.keras.layers.Activation(gelu_new), + "mish": tf.keras.layers.Activation(mish), + "tanh": tf.keras.activations.tanh, + "gelu_fast": tf.keras.layers.Activation(gelu_fast), +} + + +def get_tf_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 86e504cf9..e6b9a9c92 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -21,6 +21,7 @@ from typing import Optional, Tuple import tensorflow as tf +from .activations_tf import get_tf_activation from .configuration_albert import AlbertConfig from .file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, @@ -30,7 +31,7 @@ from .file_utils import ( add_start_docstrings_to_callable, replace_return_docstrings, ) -from .modeling_tf_bert import ACT2FN, TFBertSelfAttention +from .modeling_tf_bert import TFBertSelfAttention from .modeling_tf_outputs import ( TFBaseModelOutput, TFBaseModelOutputWithPooling, @@ -354,7 +355,7 @@ class TFAlbertLayer(tf.keras.layers.Layer): ) if isinstance(config.hidden_act, str): - self.activation = ACT2FN[config.hidden_act] + self.activation = get_tf_activation(config.hidden_act) else: self.activation = config.hidden_act @@ -494,7 +495,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" ) if isinstance(config.hidden_act, str): - self.activation = ACT2FN[config.hidden_act] + self.activation = get_tf_activation(config.hidden_act) else: self.activation = config.hidden_act diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index d3c1d9523..9cf6fbb7c 100644 ---
a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -19,9 +19,9 @@ from dataclasses import dataclass from typing import Optional, Tuple -import numpy as np import tensorflow as tf +from .activations_tf import get_tf_activation from .configuration_bert import BertConfig from .file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, @@ -88,44 +88,6 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -def gelu(x): - """Gaussian Error Linear Unit. - Original Implementation of the gelu activation function in Google Bert repo when initially created. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - Also see https://arxiv.org/abs/1606.08415 - """ - cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0))) - - return x * cdf - - -def gelu_new(x): - """Gaussian Error Linear Unit. - This is a smoother version of the RELU. - Original paper: https://arxiv.org/abs/1606.08415 - Args: - x: float Tensor to perform activation. - Returns: - `x` with the GELU activation applied. - """ - cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) - - return x * cdf - - -def swish(x): - return x * tf.sigmoid(x) - - -ACT2FN = { - "gelu": tf.keras.layers.Activation(gelu), - "relu": tf.keras.activations.relu, - "swish": tf.keras.layers.Activation(swish), - "gelu_new": tf.keras.layers.Activation(gelu_new), -} - - class TFBertEmbeddings(tf.keras.layers.Layer): """Construct the embeddings from word, position and token_type embeddings.""" @@ -352,7 +314,7 @@ class TFBertIntermediate(tf.keras.layers.Layer): ) if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] + self.intermediate_act_fn = get_tf_activation(config.hidden_act) else: self.intermediate_act_fn = config.hidden_act @@ -467,7 +429,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer): ) if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] + self.transform_act_fn = get_tf_activation(config.hidden_act) else: self.transform_act_fn = config.hidden_act diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index ca8ecb4d5..4583328c3 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -18,9 +18,9 @@ import math -import numpy as np import tensorflow as tf +from .activations_tf import get_tf_activation from .configuration_distilbert import DistilBertConfig from .file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, @@ -68,31 +68,6 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE # -def gelu(x): - """Gaussian Error Linear Unit. - Original Implementation of the gelu activation function in Google Bert repo when initially created. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - Also see https://arxiv.org/abs/1606.08415 - """ - cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.math.sqrt(2.0), dtype=x.dtype))) - return x * cdf - - -def gelu_new(x): - """Gaussian Error Linear Unit. - This is a smoother version of the RELU. - Original paper: https://arxiv.org/abs/1606.08415 - Args: - x: float Tensor to perform activation. - Returns: - `x` with the GELU activation applied. 
- """ - cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) - return x * cdf - - class TFEmbeddings(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) @@ -298,9 +273,7 @@ class TFFFN(tf.keras.layers.Layer): assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format( config.activation ) - self.activation = ( - tf.keras.layers.Activation(gelu) if config.activation == "gelu" else tf.keras.activations.relu - ) + self.activation = get_tf_activation(config.activation) def call(self, input, training=False): x = self.lin1(input) @@ -651,7 +624,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel self.vocab_transform = tf.keras.layers.Dense( config.dim, kernel_initializer=get_initializer(config.initializer_range), name="vocab_transform" ) - self.act = tf.keras.layers.Activation(gelu) + self.act = get_tf_activation("gelu") self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm") self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector") diff --git a/src/transformers/modeling_tf_electra.py b/src/transformers/modeling_tf_electra.py index 998fff293..9128ac13a 100644 --- a/src/transformers/modeling_tf_electra.py +++ b/src/transformers/modeling_tf_electra.py @@ -5,6 +5,7 @@ import tensorflow as tf from transformers import ElectraConfig +from .activations_tf import get_tf_activation from .file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, ModelOutput, @@ -13,7 +14,7 @@ from .file_utils import ( add_start_docstrings_to_callable, replace_return_docstrings, ) -from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel +from .modeling_tf_bert import TFBertEncoder, TFBertPreTrainedModel from .modeling_tf_outputs import ( TFBaseModelOutput, TFMaskedLMOutput, @@ -173,7 +174,7 @@ class TFElectraDiscriminatorPredictions(tf.keras.layers.Layer): def call(self, discriminator_hidden_states, training=False): hidden_states = self.dense(discriminator_hidden_states) - hidden_states = ACT2FN[self.config.hidden_act](hidden_states) + hidden_states = get_tf_activation(self.config.hidden_act)(hidden_states) logits = tf.squeeze(self.dense_prediction(hidden_states)) return logits @@ -188,7 +189,7 @@ class TFElectraGeneratorPredictions(tf.keras.layers.Layer): def call(self, generator_hidden_states, training=False): hidden_states = self.dense(generator_hidden_states) - hidden_states = ACT2FN["gelu"](hidden_states) + hidden_states = get_tf_activation("gelu")(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states @@ -567,7 +568,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos self.electra = TFElectraMainLayer(config, name="electra") self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions") if isinstance(config.hidden_act, str): - self.activation = ACT2FN[config.hidden_act] + self.activation = get_tf_activation(config.hidden_act) else: self.activation = config.hidden_act self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head") @@ -658,7 +659,7 @@ class TFElectraClassificationHead(tf.keras.layers.Layer): x = inputs[:, 0, :] # take token (equiv. 
to [CLS]) x = self.dropout(x) x = self.dense(x) - x = ACT2FN["gelu"](x) # although BERT uses tanh here, it seems Electra authors used gelu here + x = get_tf_activation("gelu")(x) # although BERT uses tanh here, it seems Electra authors used gelu here x = self.dropout(x) x = self.out_proj(x) diff --git a/src/transformers/modeling_tf_funnel.py b/src/transformers/modeling_tf_funnel.py index 105a2cf84..523dd9277 100644 --- a/src/transformers/modeling_tf_funnel.py +++ b/src/transformers/modeling_tf_funnel.py @@ -19,6 +19,7 @@ from typing import Optional, Tuple import tensorflow as tf +from .activations_tf import get_tf_activation from .configuration_funnel import FunnelConfig from .file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, @@ -28,7 +29,6 @@ from .file_utils import ( add_start_docstrings_to_callable, replace_return_docstrings, ) -from .modeling_tf_bert import ACT2FN from .modeling_tf_outputs import ( TFBaseModelOutput, TFMaskedLMOutput, @@ -578,7 +578,7 @@ class TFFunnelPositionwiseFFN(tf.keras.layers.Layer): super().__init__(**kwargs) initializer = get_initializer(config.initializer_range) self.linear_1 = tf.keras.layers.Dense(config.d_inner, kernel_initializer=initializer, name="linear_1") - self.activation_function = ACT2FN[config.hidden_act] + self.activation_function = get_tf_activation(config.hidden_act) self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout) self.linear_2 = tf.keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="linear_2") self.dropout = tf.keras.layers.Dropout(config.hidden_dropout) @@ -966,7 +966,7 @@ class TFFunnelDiscriminatorPredictions(tf.keras.layers.Layer): super().__init__(**kwargs) initializer = get_initializer(config.initializer_range) self.dense = tf.keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="dense") - self.activation_function = ACT2FN[config.hidden_act] + self.activation_function = get_tf_activation(config.hidden_act) self.dense_prediction = tf.keras.layers.Dense(1, kernel_initializer=initializer, name="dense_prediction") def call(self, discriminator_hidden_states): diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index d8cb4d296..95b78a525 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -19,9 +19,9 @@ from dataclasses import dataclass from typing import List, Optional, Tuple -import numpy as np import tensorflow as tf +from .activations_tf import get_tf_activation from .configuration_gpt2 import GPT2Config from .file_utils import ( ModelOutput, @@ -60,19 +60,6 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -def gelu(x): - """Gaussian Error Linear Unit. - This is a smoother version of the RELU. - Original paper: https://arxiv.org/abs/1606.08415 - Args: - x: float Tensor to perform activation. - Returns: - `x` with the GELU activation applied. 
- """ - cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) - return x * cdf - - class TFAttention(tf.keras.layers.Layer): def __init__(self, nx, n_ctx, config, scale=False, **kwargs): super().__init__(**kwargs) @@ -180,7 +167,7 @@ class TFMLP(tf.keras.layers.Layer): nx = config.n_embd self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc") self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj") - self.act = gelu + self.act = get_tf_activation("gelu") self.dropout = tf.keras.layers.Dropout(config.resid_pdrop) def call(self, x, training=False): diff --git a/src/transformers/modeling_tf_lxmert.py b/src/transformers/modeling_tf_lxmert.py index c034af1e6..ce96da0e0 100644 --- a/src/transformers/modeling_tf_lxmert.py +++ b/src/transformers/modeling_tf_lxmert.py @@ -21,11 +21,11 @@ import logging from dataclasses import dataclass from typing import Dict, Optional, Tuple -import numpy as np import tensorflow as tf from transformers import BatchEncoding +from .activations_tf import get_tf_activation from .configuration_lxmert import LxmertConfig from .file_utils import ( ModelOutput, @@ -48,42 +48,6 @@ TF_LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -def gelu(x): - """Gaussian Error Linear Unit. - Original Implementation of the gelu activation function in Google Bert repo when initially created. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - Also see https://arxiv.org/abs/1606.08415 - """ - cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0))) - return x * cdf - - -def gelu_new(x): - """Gaussian Error Linear Unit. - This is a smoother version of the RELU. - Original paper: https://arxiv.org/abs/1606.08415 - Args: - x: float Tensor to perform activation. - Returns: - `x` with the GELU activation applied. 
- """ - cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) - return x * cdf - - -def swish(x): - return x * tf.sigmoid(x) - - -ACT2FN = { - "gelu": tf.keras.layers.Activation(gelu), - "relu": tf.keras.activations.relu, - "swish": tf.keras.layers.Activation(swish), - "gelu_new": tf.keras.layers.Activation(gelu_new), -} - - @dataclass class TFLxmertModelOutput(ModelOutput): """ @@ -404,7 +368,7 @@ class TFLxmertIntermediate(tf.keras.layers.Layer): name="dense", ) if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] + self.intermediate_act_fn = get_tf_activation(config.hidden_act) else: self.intermediate_act_fn = config.hidden_act @@ -1012,7 +976,7 @@ class TFLxmertPredictionHeadTransform(tf.keras.layers.Layer): name="dense", ) if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] + self.transform_act_fn = get_tf_activation(config.hidden_act) else: self.transform_act_fn = config.hidden_act self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") @@ -1082,7 +1046,7 @@ class TFLxmertVisualAnswerHead(tf.keras.layers.Layer): kernel_initializer=get_initializer(config.initializer_range), name="logit_fc_._0", ) - self.activation = tf.keras.layers.Activation(gelu) + self.activation = get_tf_activation("gelu") self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="logit_fc_._2") self.dense_1 = tf.keras.layers.Dense( num_labels, diff --git a/src/transformers/modeling_tf_mobilebert.py b/src/transformers/modeling_tf_mobilebert.py index f188108b3..f620e3add 100644 --- a/src/transformers/modeling_tf_mobilebert.py +++ b/src/transformers/modeling_tf_mobilebert.py @@ -22,6 +22,7 @@ from typing import Optional, Tuple import tensorflow as tf from . 
import MobileBertConfig +from .activations_tf import get_tf_activation from .file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, ModelOutput, @@ -30,7 +31,7 @@ from .file_utils import ( add_start_docstrings_to_callable, replace_return_docstrings, ) -from .modeling_tf_bert import TFBertIntermediate, gelu, gelu_new, swish +from .modeling_tf_bert import TFBertIntermediate from .modeling_tf_outputs import ( TFBaseModelOutput, TFBaseModelOutputWithPooling, @@ -67,10 +68,6 @@ TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -def mish(x): - return x * tf.tanh(tf.math.softplus(x)) - - class TFLayerNorm(tf.keras.layers.LayerNormalization): def __init__(self, feat_size, *args, **kwargs): super().__init__(*args, **kwargs) @@ -89,12 +86,6 @@ class TFNoNorm(tf.keras.layers.Layer): return inputs * self.weight + self.bias -ACT2FN = { - "gelu": tf.keras.layers.Activation(gelu), - "relu": tf.keras.activations.relu, - "swish": tf.keras.layers.Activation(swish), - "gelu_new": tf.keras.layers.Activation(gelu_new), -} NORM2FN = {"layer_norm": TFLayerNorm, "no_norm": TFNoNorm} @@ -621,7 +612,7 @@ class TFMobileBertPredictionHeadTransform(tf.keras.layers.Layer): config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" ) if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] + self.transform_act_fn = get_tf_activation(config.hidden_act) else: self.transform_act_fn = config.hidden_act self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm") diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index 14ad49c93..025d02197 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -19,9 +19,9 @@ from dataclasses import dataclass from typing import Optional, Tuple -import numpy as np import tensorflow as tf +from .activations_tf import get_tf_activation from .configuration_openai import OpenAIGPTConfig from .file_utils import ( ModelOutput, @@ -56,30 +56,6 @@ TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -def gelu(x): - """Gaussian Error Linear Unit. - This is a smoother version of the RELU. - Original paper: https://arxiv.org/abs/1606.08415 - Args: - x: float Tensor to perform activation. - Returns: - `x` with the GELU activation applied. 
- """ - cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) - return x * cdf - - -def swish(x): - return x * tf.math.sigmoid(x) - - -ACT_FNS = { - "gelu": tf.keras.layers.Activation(gelu), - "relu": tf.keras.activations.relu, - "swish": tf.keras.layers.Activation(swish), -} - - class TFAttention(tf.keras.layers.Layer): def __init__(self, nx, n_ctx, config, scale=False, **kwargs): super().__init__(**kwargs) @@ -179,7 +155,7 @@ class TFMLP(tf.keras.layers.Layer): nx = config.n_embd self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc") self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj") - self.act = gelu + self.act = get_tf_activation("gelu") self.dropout = tf.keras.layers.Dropout(config.resid_pdrop) def call(self, x, training=False): diff --git a/src/transformers/modeling_tf_roberta.py b/src/transformers/modeling_tf_roberta.py index e964fab7c..e0f35aa88 100644 --- a/src/transformers/modeling_tf_roberta.py +++ b/src/transformers/modeling_tf_roberta.py @@ -18,6 +18,7 @@ import tensorflow as tf +from .activations_tf import get_tf_activation from .configuration_roberta import RobertaConfig from .file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, @@ -25,7 +26,7 @@ from .file_utils import ( add_start_docstrings, add_start_docstrings_to_callable, ) -from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu +from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer from .modeling_tf_outputs import ( TFBaseModelOutputWithPooling, TFMaskedLMOutput, @@ -237,7 +238,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer): config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" ) self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.act = tf.keras.layers.Activation(gelu) + self.act = get_tf_activation("gelu") # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index b504a6f13..2e1d10e6b 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -25,6 +25,7 @@ from typing import Optional, Tuple import numpy as np import tensorflow as tf +from .activations_tf import get_tf_activation from .configuration_xlm import XLMConfig from .file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, @@ -82,17 +83,6 @@ def create_sinusoidal_embeddings(n_pos, dim, out): out[:, 1::2] = tf.constant(np.cos(position_enc[:, 1::2])) -def gelu(x): - """Gaussian Error Linear Unit. - Original Implementation of the gelu activation function in Google Bert repo when initially created. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - Also see https://arxiv.org/abs/1606.08415 - """ - cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0))) - return x * cdf - - def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32): """ Generate hidden states mask, and optionally an attention mask. 
@@ -216,7 +206,7 @@ class TFTransformerFFN(tf.keras.layers.Layer): super().__init__(**kwargs) self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1") self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2") - self.act = tf.keras.layers.Activation(gelu) if config.gelu_activation else tf.keras.activations.relu + self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu") self.dropout = tf.keras.layers.Dropout(config.dropout) def call(self, input, training=False): diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index 50e3232d8..1bdbee870 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -20,9 +20,9 @@ from dataclasses import dataclass from typing import List, Optional, Tuple -import numpy as np import tensorflow as tf +from .activations_tf import get_tf_activation from .configuration_xlnet import XLNetConfig from .file_utils import ( MULTIPLE_CHOICE_DUMMY_INPUTS, @@ -61,26 +61,6 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -def gelu(x): - """Implementation of the gelu activation function. - XLNet is using OpenAI GPT's gelu - Also see https://arxiv.org/abs/1606.08415 - """ - cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) - return x * cdf - - -def swish(x): - return x * tf.sigmoid(x) - - -ACT2FN = { - "gelu": tf.keras.layers.Activation(gelu), - "relu": tf.keras.activations.relu, - "swish": tf.keras.layers.Activation(swish), -} - - class TFXLNetRelativeAttention(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) @@ -356,7 +336,7 @@ class TFXLNetFeedForward(tf.keras.layers.Layer): ) self.dropout = tf.keras.layers.Dropout(config.dropout) if isinstance(config.ff_activation, str): - self.activation_function = ACT2FN[config.ff_activation] + self.activation_function = get_tf_activation(config.ff_activation) else: self.activation_function = config.ff_activation diff --git a/tests/test_activations_tf.py b/tests/test_activations_tf.py new file mode 100644 index 000000000..bdaecff40 --- /dev/null +++ b/tests/test_activations_tf.py @@ -0,0 +1,24 @@ +import unittest + +from transformers import is_tf_available +from transformers.testing_utils import require_tf + + +if is_tf_available(): + from transformers.activations_tf import get_tf_activation + + +@require_tf +class TestTFActivations(unittest.TestCase): + def test_get_activation(self): + get_tf_activation("swish") + get_tf_activation("gelu") + get_tf_activation("relu") + get_tf_activation("tanh") + get_tf_activation("gelu_new") + get_tf_activation("gelu_fast") + get_tf_activation("mish") + with self.assertRaises(KeyError): + get_tf_activation("bogus") + with self.assertRaises(KeyError): + get_tf_activation(None)
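
Example usage (not part of the diff): a minimal sketch of how a model layer resolves a config string through the new helper, assuming the activations_tf module above is importable as transformers.activations_tf. The DemoIntermediate layer, its unit count, and its input shape are hypothetical; the pattern simply mirrors TFBertIntermediate from the diff.

import tensorflow as tf

from transformers.activations_tf import get_tf_activation


class DemoIntermediate(tf.keras.layers.Layer):
    # Hypothetical layer mirroring the TFBertIntermediate pattern above:
    # the config string is resolved to a callable once, at construction time.
    def __init__(self, hidden_act="gelu", units=64, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units)
        # Accept either a string name or an already-callable activation,
        # exactly as the isinstance checks in the diff do.
        self.act = get_tf_activation(hidden_act) if isinstance(hidden_act, str) else hidden_act

    def call(self, hidden_states):
        return self.act(self.dense(hidden_states))


layer = DemoIntermediate(hidden_act="gelu_new")
out = layer(tf.random.normal((2, 8)))
print(out.shape)  # (2, 64)

Unknown names raise KeyError, as exercised by the new test file: get_tf_activation("bogus") fails with the available ACT2FN keys listed in the message.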
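
A second illustrative check, not part of the diff: the tanh-based variants (gelu_new, gelu_fast) should track the erf-based gelu closely, which is a quick way to confirm the gelu_fast coefficients (0.7978845608 ≈ sqrt(2/pi), 0.044715) are wired in correctly. The tolerance below is an assumption chosen for this input range.

import tensorflow as tf

from transformers.activations_tf import get_tf_activation

x = tf.constant([-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0])

erf_gelu = get_tf_activation("gelu")(x)
tanh_gelu = get_tf_activation("gelu_new")(x)
fast_gelu = get_tf_activation("gelu_fast")(x)

# Both tanh approximations stay within roughly 4e-4 of the exact erf form here.
tf.debugging.assert_near(erf_gelu, tanh_gelu, atol=1e-3)
tf.debugging.assert_near(erf_gelu, fast_gelu, atol=1e-3)
print("gelu variants agree")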