Flax T5 (#12150)
* copy pytorch-t5
* init
* boom boom
* forward pass same
* make generation work
* add more tests
* make test work
* finish normal tests
* make fix-copies
* finish quality
* correct slow example
* correct slow test
* version table
* upload models
* Update tests/test_modeling_flax_t5.py
* correct incorrectly deleted line

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent 7d4cfa3b47
commit e98233dde1

13 changed files with 2180 additions and 7 deletions
@@ -396,7 +396,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | SqueezeBERT                 | ✅             | ✅             | ✅              | ❌                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| T5                          | ✅             | ✅             | ✅              | ✅                 | ❌           |
+| T5                          | ✅             | ✅             | ✅              | ✅                 | ✅           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | TAPAS                       | ✅             | ❌             | ✅              | ❌                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
@@ -160,3 +160,15 @@ TFT5EncoderModel
 .. autoclass:: transformers.TFT5EncoderModel
     :members: call
 
+FlaxT5Model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxT5Model
+    :members: __call__, encode, decode
+
+
+FlaxT5ForConditionalGeneration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxT5ForConditionalGeneration
+    :members: __call__, encode, decode
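For readers who want to try the classes documented above, here is a minimal sketch. It assumes a working jax/flax installation; the "t5-small" checkpoint and the exact tensor plumbing are illustrative choices, not part of this diff.

    import jax.numpy as jnp
    from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")

    inputs = tokenizer("summarize: Hello there", return_tensors="np")

    # encode/decode expose the two halves of the seq2seq stack separately
    encoder_outputs = model.encode(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
    decoder_start = jnp.full((1, 1), model.config.decoder_start_token_id, dtype="i4")
    decoder_outputs = model.decode(decoder_start, encoder_outputs)
    print(decoder_outputs.logits.shape)

    # __call__ runs the full encoder-decoder forward pass in one go
    logits = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_start).logits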
setup.py (5 changed lines)
@@ -114,6 +114,7 @@ _deps = [
     "onnxruntime-tools>=1.4.2",
     "onnxruntime>=1.4.0",
     "optuna",
+    "optax>=0.0.8",
     "packaging",
     "parameterized",
     "protobuf",
@@ -234,7 +235,7 @@ if os.name == "nt":  # windows
     extras["flax"] = []  # jax is not supported on windows
 else:
     extras["retrieval"] = deps_list("faiss-cpu", "datasets")
-    extras["flax"] = deps_list("jax", "jaxlib", "flax")
+    extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax")
 
 extras["tokenizers"] = deps_list("tokenizers")
 extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
@@ -325,7 +326,7 @@ install_requires = [
     deps["huggingface-hub"],
     deps["numpy"],
     deps["packaging"],  # utilities from PyPA to e.g., compare versions
-    deps["pyyaml"],  # used for the model cards metadata
+    deps["pyyaml"],  # used for the model cards metadata
     deps["regex"],  # for OpenAI GPT
     deps["requests"],  # for downloading models over HTTPS
     deps["sacremoses"],  # for XLM
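A note on why optax appears both in _deps and in the flax extra: optax is the optimizer and gradient-transformation library commonly used to train Flax models, so installing the extra (for example pip install "transformers[flax]") now pulls it in as well. A small self-contained sketch of the kind of usage this enables; the parameter tree and learning rate below are placeholders, not part of this diff.

    import jax.numpy as jnp
    import optax

    params = {"w": jnp.zeros((4, 4))}       # stand-in for a real model's params pytree
    tx = optax.adamw(learning_rate=3e-4)    # provided by the new optax>=0.0.8 pin
    opt_state = tx.init(params)

    grads = {"w": jnp.ones((4, 4))}         # pretend gradients
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)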
@@ -1597,6 +1597,7 @@ if is_flax_available():
             "FlaxRobertaPreTrainedModel",
         ]
     )
+    _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"])
     _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
 else:
     from .utils import dummy_flax_objects
@@ -2920,6 +2921,7 @@ if TYPE_CHECKING:
             FlaxRobertaModel,
             FlaxRobertaPreTrainedModel,
         )
+        from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
         from .models.vit import FlaxViTForImageClassification, FlaxViTModel
     else:
         # Import the same objects as dummies to get them in the namespace.
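Whether or not jax/flax is installed (real classes in the first case, the dummy objects added later in this commit otherwise), both new names are now reachable from the top-level package; a quick sanity check:

    import transformers

    print(transformers.FlaxT5Model)
    print(transformers.FlaxT5ForConditionalGeneration)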
@@ -31,6 +31,7 @@ deps = {
     "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
     "onnxruntime": "onnxruntime>=1.4.0",
     "optuna": "optuna",
+    "optax": "optax>=0.0.8",
     "packaging": "packaging",
     "parameterized": "parameterized",
     "protobuf": "protobuf",
@@ -62,6 +62,7 @@ from ..roberta.modeling_flax_roberta import (
     FlaxRobertaForTokenClassification,
     FlaxRobertaModel,
 )
+from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
 from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
 from .auto_factory import auto_class_factory
 from .configuration_auto import (
@@ -72,6 +73,7 @@ from .configuration_auto import (
     ElectraConfig,
     GPT2Config,
     RobertaConfig,
+    T5Config,
     ViTConfig,
 )
 
@@ -90,6 +92,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
         (ElectraConfig, FlaxElectraModel),
         (CLIPConfig, FlaxCLIPModel),
         (ViTConfig, FlaxViTModel),
+        (T5Config, FlaxT5Model),
     ]
 )
 
@@ -101,6 +104,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
         (BigBirdConfig, FlaxBigBirdForPreTraining),
         (BartConfig, FlaxBartForConditionalGeneration),
         (ElectraConfig, FlaxElectraForPreTraining),
+        (T5Config, FlaxT5ForConditionalGeneration),
     ]
 )
 
@@ -115,6 +119,14 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
     ]
 )
 
+FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
+    [
+        # Model for Seq2Seq Causal LM mapping
+        (BartConfig, FlaxBartForConditionalGeneration),
+        (T5Config, FlaxT5ForConditionalGeneration),
+    ]
+)
+
 FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
     [
         # Model for Image-classsification
@@ -234,3 +246,9 @@ FlaxAutoModelForNextSentencePrediction = auto_class_factory(
     FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
     head_doc="next sentence prediction",
 )
+
+FlaxAutoModelForSeq2SeqLM = auto_class_factory(
+    "FlaxAutoModelForSeq2SeqLM",
+    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    head_doc="sequence-to-sequence language modeling",
+)
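With the new FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING and the FlaxAutoModelForSeq2SeqLM factory in place, T5 checkpoints can be loaded through the Flax auto API. A small sketch, assuming jax/flax are installed and that the factory class exposes the usual from_pretrained entry point; "t5-small" is just an example checkpoint, and the auto class is imported from its defining module because a top-level export is not part of this diff.

    from transformers import FlaxT5ForConditionalGeneration
    from transformers.models.auto.modeling_flax_auto import FlaxAutoModelForSeq2SeqLM

    model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")
    # The (T5Config, FlaxT5ForConditionalGeneration) entry resolves T5 configs to the Flax T5 head.
    assert isinstance(model, FlaxT5ForConditionalGeneration)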
@@ -229,7 +229,6 @@ class FlaxBartAttention(nn.Module):
     embed_dim: int
     num_heads: int
     dropout: float = 0.0
-    is_decoder: bool = False
     causal: bool = False
     bias: bool = True
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
@@ -510,7 +509,6 @@ class FlaxBartDecoderLayer(nn.Module):
             embed_dim=self.embed_dim,
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
-            is_decoder=True,
             causal=True,
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -523,7 +521,6 @@ class FlaxBartDecoderLayer(nn.Module):
             embed_dim=self.embed_dim,
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
-            is_decoder=True,
         )
         self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.fc1 = nn.Dense(
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
 
 from ...file_utils import (
     _BaseLazyModule,
+    is_flax_available,
     is_sentencepiece_available,
     is_tf_available,
     is_tokenizers_available,
@@ -56,6 +57,13 @@ if is_tf_available():
         "TFT5PreTrainedModel",
     ]
 
+if is_flax_available():
+    _import_structure["modeling_flax_t5"] = [
+        "FlaxT5ForConditionalGeneration",
+        "FlaxT5Model",
+        "FlaxT5PreTrainedModel",
+    ]
+
 
 if TYPE_CHECKING:
     from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
@@ -85,6 +93,10 @@ if TYPE_CHECKING:
             TFT5PreTrainedModel,
         )
 
+    if is_flax_available():
+        from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+
+
 else:
     import importlib
     import os
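The same availability check used in the import-structure changes above can guard optional Flax imports in downstream code; a purely illustrative sketch:

    from transformers.file_utils import is_flax_available

    if is_flax_available():
        # Resolved through the lazy _import_structure entries added above.
        from transformers.models.t5 import FlaxT5Model, FlaxT5PreTrainedModel
    else:
        FlaxT5Model = None  # degrade gracefully when jax/flax are not installed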
src/transformers/models/t5/modeling_flax_t5.py (new file, 1584 lines)
File diff suppressed because it is too large.
@@ -570,6 +570,24 @@ class FlaxRobertaPreTrainedModel:
         requires_backends(cls, ["flax"])
 
 
+class FlaxT5ForConditionalGeneration:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxT5Model:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
 class FlaxViTForImageClassification:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
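These dummies keep "from transformers import FlaxT5Model" importable even when jax/flax are missing, while failing loudly on first use. A sketch of that behaviour, assuming an environment without flax and that requires_backends raises an ImportError, as it does in transformers releases of this period:

    from transformers import FlaxT5Model  # resolves to the dummy class when flax is unavailable

    try:
        FlaxT5Model()                # any attempt to instantiate...
    except ImportError as err:       # ...fails with instructions for installing Flax
        print(err)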
@@ -72,7 +72,7 @@ def prepare_bart_inputs_dict(
     }
 
 
-class FlaxBartModelTester(unittest.TestCase):
+class FlaxBartModelTester:
     def __init__(
         self,
         parent,
tests/test_modeling_flax_t5.py (new file, 513 lines)
File diff suppressed because one or more lines are too long.
@@ -794,6 +794,21 @@ class T5ModelIntegrationTests(unittest.TestCase):
     def tokenizer(self):
         return T5Tokenizer.from_pretrained("t5-base")
 
+    @slow
+    def test_small_generation(self):
+        model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
+        model.config.max_length = 8
+        model.config.num_beams = 1
+        model.config.do_sample = False
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids
+
+        sequences = model.generate(input_ids)
+
+        output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+        self.assertTrue(output_str == "Hello there!")
+
     @slow
     def test_small_integration_test(self):
         """
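The new PyTorch slow test has a natural Flax counterpart. A rough sketch of the equivalent check, not part of this diff; it assumes jax/flax are installed and that the Flax generation utilities of this version return an output object with a sequences field:

    from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

    model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids

    # Greedy decoding, mirroring max_length=8 / num_beams=1 / do_sample=False above.
    sequences = model.generate(input_ids, max_length=8).sequences

    output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
    print(output_str)  # should reproduce the greedy continuation checked in the PyTorch test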