[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)
* save intermediate
* save intermediate
* save intermediate
* correct flax bert model file
* new module / model naming
* make style
* almost finish BERT
* finish roberta
* make fix-copies
* delete keys file
* last refactor
* fixes in run_mlm_flax.py
* remove pooled from run_mlm_flax.py
* fix gelu | gelu_new
* remove Module from inits
* splits
* dirty print
* preventing warmup_steps == 0
* smaller splits
* make fix-copies
* dirty print
* dirty print
* initial_evaluation argument
* declaration order fix
* proper model initialization/loading
* proper initialization
* run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug
* removed tokenizers warning hack, fixed model re-initialization
* reverted training_args.py changes
* fix flax from pretrained
* improve test in flax
* apply sylvains tips
* update init
* make 0.3.0 compatible
* revert tevens changes
* revert tevens changes 2
* finalize revert
* fix bug
* add docs
* add pretrained to init
* Update src/transformers/modeling_flax_utils.py
* fix copies
* final improvements

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
parent 51adb97cd6
commit 640e6fe190

14 changed files with 700 additions and 359 deletions
docs/source/main_classes/model.rst

@@ -13,9 +13,10 @@
 Models
 -----------------------------------------------------------------------------------------------------------------------
 
-The base classes :class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` implement the
-common methods for loading/saving a model either from a local file or directory, or from a pretrained model
-configuration provided by the library (downloaded from HuggingFace's AWS S3 repository).
+The base classes :class:`~transformers.PreTrainedModel`, :class:`~transformers.TFPreTrainedModel`, and
+:class:`~transformers.FlaxPreTrainedModel` implement the common methods for loading/saving a model either from a local
+file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace's AWS
+S3 repository).
 
 :class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` also implement a few methods which
 are common among all the models to:
@@ -57,6 +58,13 @@ TFModelUtilsMixin
     :members:
 
 
+FlaxPreTrainedModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxPreTrainedModel
+    :members:
+
+
 Generation
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
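The doc change above extends the shared loading/saving interface to Flax. A minimal usage sketch of that interface (the checkpoint id is illustrative; any Flax checkpoint on huggingface.co works):

    from transformers import FlaxBertModel

    # fetches config.json + flax_model.msgpack and fills model.params
    model = FlaxBertModel.from_pretrained("bert-base-cased")
    model.save_pretrained("./my_model_directory")  # writes both files back to disk
    model = FlaxBertModel.from_pretrained("./my_model_directory")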
run_mlm_flax.py

@@ -385,7 +385,7 @@ def training_step(optimizer, batch, dropout_rng):
         # Hide away tokens which doesn't participate in the optimization
         token_mask = jnp.where(targets > 0, 1.0, 0.0)
 
-        pooled, logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)
+        logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
         loss, weight_sum = cross_entropy(logits, targets, token_mask)
         return loss / weight_sum
 
@@ -407,7 +407,7 @@ def eval_step(params, batch):
 
     # Hide away tokens which doesn't participate in the optimization
     token_mask = jnp.where(targets > 0, 1.0, 0.0)
-    _, logits = model(**batch, params=params, train=False)
+    logits = model(**batch, params=params, train=False)[0]
 
     return compute_metrics(logits, targets, token_mask)
 
@@ -572,8 +572,13 @@ if __name__ == "__main__":
     rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
 
-    model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", dtype=jnp.float32, dropout_rate=0.1)
-    model.init(jax.random.PRNGKey(training_args.seed), (training_args.train_batch_size, model.config.max_length))
+    model = FlaxBertForMaskedLM.from_pretrained(
+        "bert-base-cased",
+        dtype=jnp.float32,
+        input_shape=(training_args.train_batch_size, config.max_position_embeddings),
+        seed=training_args.seed,
+        dropout_rate=0.1,
+    )
 
     # Setup optimizer
     optimizer = Adam(
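The two `[0]` changes above track the new return convention introduced in this commit: Flax models now return a tuple with the logits first rather than a (pooled, logits) pair. Roughly, at the call site:

    # before: pooled, logits = model(**batch, ...)
    # after: the model returns a tuple, logits first
    logits = model(**batch, params=params, train=False)[0]  # shape (batch_size, seq_len, vocab_size)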
setup.py (14 changed lines)

@@ -97,7 +97,7 @@ _deps = [
     "fastapi",
     "filelock",
     "flake8>=3.8.3",
-    "flax==0.2.2",
+    "flax>=0.2.2",
     "fugashi>=1.0",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
@@ -175,7 +175,7 @@ class DepsTableUpdateCommand(Command):
             "deps = {",
             entries,
             "}",
-            ""
+            "",
         ]
         target = "src/transformers/dependency_versions_table.py"
         print(f"updating {target}")
@@ -232,14 +232,14 @@ extras["dev"] = (
 # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
 install_requires = [
     deps["dataclasses"] + ";python_version<'3.7'",  # dataclasses for Python versions that don't have it
-    deps["filelock"],  # filesystem locks, e.g., to prevent parallel downloads
+    deps["filelock"],  # filesystem locks, e.g., to prevent parallel downloads
     deps["numpy"],
     deps["packaging"],  # utilities from PyPA to e.g., compare versions
-    deps["regex"],  # for OpenAI GPT
-    deps["requests"],  # for downloading models over HTTPS
-    deps["sacremoses"],  # for XLM
+    deps["regex"],  # for OpenAI GPT
+    deps["requests"],  # for downloading models over HTTPS
+    deps["sacremoses"],  # for XLM
     deps["tokenizers"],
-    deps["tqdm"],  # progress bars in model download and training scripts
+    deps["tqdm"],  # progress bars in model download and training scripts
 ]
 
 setup(
src/transformers/__init__.py

@@ -945,6 +945,7 @@ else:
 
 
 if is_flax_available():
     from .modeling_flax_utils import FlaxPreTrainedModel
     from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
+    from .models.bert import FlaxBertForMaskedLM, FlaxBertModel
     from .models.roberta import FlaxRobertaModel
src/transformers/dependency_versions_table.py

@@ -10,7 +10,7 @@ deps = {
     "fastapi": "fastapi",
     "filelock": "filelock",
     "flake8": "flake8>=3.8.3",
-    "flax": "flax==0.2.2",
+    "flax": "flax>=0.2.2",
     "fugashi": "fugashi>=1.0",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
@@ -40,8 +40,8 @@ deps = {
     "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
     "sphinx": "sphinx==3.2.1",
     "starlette": "starlette",
-    "tensorflow-cpu": "tensorflow-cpu>=2.0",
-    "tensorflow": "tensorflow>=2.0",
+    "tensorflow-cpu": "tensorflow-cpu>=2.0,<2.4",
+    "tensorflow": "tensorflow>=2.0,<2.4",
     "timeout-decorator": "timeout-decorator",
     "tokenizers": "tokenizers==0.9.4",
     "torch": "torch>=1.0",
src/transformers/file_utils.py

@@ -270,6 +270,7 @@ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
 WEIGHTS_NAME = "pytorch_model.bin"
 TF2_WEIGHTS_NAME = "tf_model.h5"
 TF_WEIGHTS_NAME = "model.ckpt"
+FLAX_WEIGHTS_NAME = "flax_model.msgpack"
 CONFIG_NAME = "config.json"
 MODEL_CARD_NAME = "modelcard.json"
 
src/transformers/modeling_flax_utils.py

@@ -15,64 +15,65 @@
 
 import os
 from abc import ABC, abstractmethod
+from functools import partial
 from pickle import UnpicklingError
-from typing import Dict
+from typing import Dict, Set, Tuple, Union
 
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.serialization import to_bytes
-from flax.traverse_util import unflatten_dict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.serialization import from_bytes, to_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
 
 from .configuration_utils import PretrainedConfig
-from .file_utils import WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
+from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
 from .utils import logging
 
 
 logger = logging.get_logger(__name__)
 
 
 @jax.jit
 def gelu(x):
     r"""
     Gaussian error linear unit activation function.
 
     Computes the element-wise function:
 
     .. math::
         \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
            \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
 
     We explicitly use the approximation rather than the exact formulation for speed. For more information, see
     `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_, section 2.
     """
     return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0)))
 
 
 ACT2FN = {
     "gelu": nn.gelu,
     "relu": nn.relu,
     "silu": nn.swish,
     "swish": nn.swish,
-    "gelu_new": gelu,
+    "gelu_new": partial(nn.gelu, approximate=True),
 }
 
 
 class FlaxPreTrainedModel(ABC):
     r"""
     Base class for all models.
 
     :class:`~transformers.FlaxPreTrainedModel` takes care of storing the configuration of the models and handles
     methods for loading, downloading and saving models.
 
     Class attributes (overridden by derived classes):
 
         - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
     """
     config_class = None
     pretrained_model_archive_map = {}
     base_model_prefix = ""
     model_class = None
 
     def __init__(
-        self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0, dtype: jnp.dtype = jnp.float32
+        self,
+        config: PretrainedConfig,
+        module: nn.Module,
+        input_shape: Tuple = (1, 1),
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
     ):
         if config is None:
             raise ValueError("config cannot be None")
 
-        if params is None:
-            raise ValueError("state cannot be None")
+        if module is None:
+            raise ValueError("module cannot be None")
 
         # Those are private to be exposed as typed property on derived classes.
         self._config = config
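With this hunk, ``gelu_new`` is resolved through ``ACT2FN`` to Flax's built-in tanh approximation instead of the hand-written kernel. A quick sketch of the lookup (input values illustrative):

    from functools import partial

    import flax.linen as nn
    import jax.numpy as jnp

    ACT2FN = {"gelu": nn.gelu, "gelu_new": partial(nn.gelu, approximate=True)}
    x = jnp.array([-1.0, 0.0, 1.0])
    y = ACT2FN["gelu_new"](x)  # equivalent to nn.gelu(x, approximate=True)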
@@ -80,9 +81,18 @@ class FlaxPreTrainedModel(ABC):
 
         # Those are public as their type is generic to every derived classes.
         self.key = PRNGKey(seed)
-        self.params = params
         self.dtype = dtype
 
+        # randomely initialized parameters
+        random_params = self.init(self.key, input_shape)
+
+        # save required_params as set
+        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
+        self.params = random_params
+
+    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
+        raise NotImplementedError(f"init method has to be implemented for {self}")
+
     @property
     def config(self) -> PretrainedConfig:
         return self._config
@@ -91,24 +101,130 @@ class FlaxPreTrainedModel(ABC):
     def module(self) -> nn.Module:
         return self._module
 
+    @property
+    def params(self) -> Union[Dict, FrozenDict]:
+        return self._params
+
+    @property
+    def required_params(self) -> Set:
+        return self._required_params
+
+    @params.setter
+    def params(self, params: Union[Dict, FrozenDict]):
+        if isinstance(params, FrozenDict):
+            params = unfreeze(params)
+        param_keys = set(flatten_dict(params).keys())
+        if len(self.required_params - param_keys) > 0:
+            raise ValueError(
+                "Some parameters are missing. Make sure that `params` include the following "
+                f"parameters {self.required_params - param_keys}"
+            )
+        self._params = freeze(params)
+
     @staticmethod
     @abstractmethod
     def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
         raise NotImplementedError()
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        dtype: jnp.dtype = jnp.float32,
+        *model_args,
+        **kwargs
+    ):
+
         r"""
-        Instantiate a pretrained Flax model from a pre-trained model configuration.
+        Instantiate a pretrained flax model from a pre-trained model configuration.
+
+        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
+                Can be either:
+
+                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
+                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
+                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
+                    - A path to a `directory` containing model weights saved using
+                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
+                    - A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
+                      case, ``from_pt`` should be set to :obj:`True`.
+            model_args (sequence of positional arguments, `optional`):
+                All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
+            config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
+                Can be either:
+
+                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
+                    - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
+
+                Configuration for the model to use instead of an automatically loaded configuation. Configuration can
+                be automatically loaded when:
+
+                    - The model is a model provided by the library (loaded with the `model id` string of a pretrained
+                      model).
+                    - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
+                      by supplying the save directory.
+                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
+                      configuration JSON file named `config.json` is found in the directory.
+            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Load the model weights from a PyTorch checkpoint save file (see docstring of
+                ``pretrained_model_name_or_path`` argument).
+            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (:obj:`Dict[str, str], `optional`):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
+            kwargs (remaining dictionary of keyword arguments, `optional`):
+                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
+                automatically loaded:
+
+                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
+                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
+                      already been done)
+                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
+                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
+                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
+                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
+                      attribute will be passed to the underlying model's ``__init__`` function.
+
+        Examples::
+
+            >>> from transformers import BertConfig, FlaxBertModel
+            >>> # Download model and configuration from huggingface.co and cache.
+            >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
+            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
+            >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
+            >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
+            >>> config = BertConfig.from_json_file('./pt_model/config.json')
+            >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
         """
         config = kwargs.pop("config", None)
         # state_dict = kwargs.pop("state_dict", None)
         cache_dir = kwargs.pop("cache_dir", None)
         # from_tf = kwargs.pop("from_tf", False)
         from_pt = kwargs.pop("from_pt", False)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         # output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
 
@@ -135,10 +251,28 @@ class FlaxPreTrainedModel(ABC):
 
         # Load model
         if pretrained_model_name_or_path is not None:
-            if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+            if os.path.isdir(pretrained_model_name_or_path):
+                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    # Load from a PyTorch checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
+                    # Load from a Flax checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
+                else:
+                    raise EnvironmentError(
+                        "Error no file named {} found in directory {} or `from_pt` set to False".format(
+                            [FLAX_WEIGHTS_NAME, WEIGHTS_NAME],
+                            pretrained_model_name_or_path,
+                        )
+                    )
+            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                 archive_file = pretrained_model_name_or_path
             else:
-                archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
+                    revision=revision,
+                )
 
         # redirect to the cache, if necessary
         try:
@@ -169,31 +303,96 @@ class FlaxPreTrainedModel(ABC):
         # Instantiate model.
         with open(resolved_archive_file, "rb") as state_f:
             try:
-                from flax.serialization import from_bytes
-
-                state = from_bytes(cls.model_class, state_f)
-            except TypeError:
-                try:
+                if from_pt:
                     import torch
 
                     state = torch.load(state_f)
-                    state = {k: v.numpy() for k, v in state.items()}
-                    state = cls.convert_from_pytorch(state, config)
-                    state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
-                except UnpicklingError:
-                    raise EnvironmentError(
-                        f"Unable to convert model {archive_file} to Flax deserializable object. "
-                        "Supported format are PyTorch archive or Flax msgpack"
-                    )
-
-        return cls(config, state, *model_args, **model_kwargs)
-
-    def save_pretrained(self, folder):
-        folder_abs = os.path.abspath(folder)
-
-        if not os.path.exists(folder_abs):
-            os.mkdir(folder_abs)
-
-        with open(os.path.join(folder_abs, f"{self._config.model_type}.flax", "wb")) as f:
+                    state = convert_state_dict_from_pt(cls, state, config)
+                else:
+                    state = from_bytes(cls, state_f.read())
+            except UnpicklingError:
+                raise EnvironmentError(
+                    f"Unable to convert pytorch model {archive_file} to Flax deserializable object. "
+                )
+
+        # init random models
+        model = cls(config, *model_args, **model_kwargs)
+
+        # if model is base model only use model_prefix key
+        if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
+            state = state[cls.base_model_prefix]
+
+        # flatten dicts
+        state = flatten_dict(state)
+        random_state = flatten_dict(unfreeze(model.params))
+
+        missing_keys = model.required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - model.required_params
+
+        # add missing keys as random parameters
+        for missing_key in missing_keys:
+            state[missing_key] = random_state[missing_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
+                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
+                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
+                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized: {missing_keys}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        else:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"If your task is similar to the task the model of the checkpoint was trained on, "
+                f"you can already use {model.__class__.__name__} for predictions without further training."
+            )
+
+        # set correct parameters
+        model.params = unflatten_dict(state)
+        return model
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        `:func:`~transformers.FlaxPreTrainedModel.from_pretrained`` class method
+
+        Arguments:
+            save_directory (:obj:`str` or :obj:`os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+        """
+        if os.path.isfile(save_directory):
+            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
+            return
+        os.makedirs(save_directory, exist_ok=True)
+
+        # get abs dir
+        save_directory = os.path.abspath(save_directory)
+        # save config as well
+        self.config.save_pretrained(save_directory)
+
+        # save model
+        with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
             model_bytes = to_bytes(self.params)
             f.write(model_bytes)
+
+
+def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig):
+    """
+    Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict
+    """
+    state = {k: v.numpy() for k, v in state.items()}
+    state = model_class.convert_from_pytorch(state, config)
+    state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
+    return state
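Taken together, the loading path above now works in three steps: ``from_pretrained`` first builds a randomly initialized model (which fixes ``required_params``), then deserializes the checkpoint (converting via ``convert_state_dict_from_pt`` when ``from_pt=True``), and finally backfills missing keys from the random init before assigning through the validating ``params`` setter. A sketch of the resulting API (checkpoint id and input shape are illustrative):

    import jax.numpy as jnp
    from transformers import FlaxBertModel

    model = FlaxBertModel.from_pretrained("bert-base-cased", input_shape=(1, 128), seed=0, dtype=jnp.float32)
    model.save_pretrained("./saved")  # writes config.json and flax_model.msgpack
    reloaded = FlaxBertModel.from_pretrained("./saved")

    # the params setter rejects incomplete parameter trees:
    # reloaded.params = {}  # raises ValueError: Some parameters are missing ...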
src/transformers/models/auto/modeling_flax_auto.py

@@ -27,15 +27,6 @@ from .configuration_auto import AutoConfig, BertConfig, RobertaConfig
 logger = logging.get_logger(__name__)
 
 
-ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
-    (key, value)
-    for pretrained_map in [
-        FlaxBertModel.pretrained_model_archive_map,
-        FlaxRobertaModel.pretrained_model_archive_map,
-    ]
-    for key, value, in pretrained_map.items()
-)
-
 FLAX_MODEL_MAPPING = OrderedDict(
     [
         (RobertaConfig, FlaxRobertaModel),
@@ -114,10 +105,9 @@ class FlaxAutoModel(object):
                       organization name, like ``dbmdz/bert-base-german-cased``.
                     - a path to a `directory` containing model weights saved using
                       :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
-                    - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this
-                      case, ``from_tf`` should be set to True and a configuration object should be provided as ``config``
-                      argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model
-                      using the provided conversion scripts and loading the PyTorch model afterwards.
+                    - a path or url to a `pytorch index checkpoint file` (e.g. `./pt_model/pytorch_model.bin`). In this
+                      case, ``from_pt`` should be set to True and a configuration object should be provided as ``config``
+                      argument.
 
                 model_args: (`optional`) Sequence of positional arguments:
                     All remaining positional arguments will be passed to the underlying model's ``__init__`` method
@@ -133,13 +123,6 @@ class FlaxAutoModel(object):
                     - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
 
-                state_dict: (`optional`) dict:
-                    an optional state dictionary for the model to use instead of a state dictionary loaded from saved
-                    weights file. This option can be used if you want to create a model from a pretrained configuration but
-                    load your own weights. In this case though, you should check if using
-                    :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and
-                    :func:`~transformers.FlaxPreTrainedModel.from_pretrained` is not a simpler option.
-
                 cache_dir: (`optional`) string:
                     Path to a directory in which a downloaded pre-trained model configuration should be cached if the
                     standard cache should not be used.
src/transformers/models/bert/modeling_flax_bert.py

@@ -20,10 +20,11 @@ import numpy as np
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
 from jax.random import PRNGKey
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
 from ...utils import logging
 from .configuration_bert import BertConfig
 
@@ -205,7 +206,7 @@ class FlaxBertAttention(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     @nn.compact
-    def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
+    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
         # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
         # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
         # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
@@ -219,27 +220,28 @@ class FlaxBertAttention(nn.Module):
             bias_init=jax.nn.initializers.zeros,
             name="self",
             dtype=self.dtype,
-        )(hidden_state, attention_mask)
+        )(hidden_states, attention_mask)
 
-        layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
+        layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
         return layer_norm
 
 
 class FlaxBertIntermediate(nn.Module):
     output_size: int
+    hidden_act: str = "gelu"
     kernel_init_scale: float = 0.2
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     @nn.compact
-    def __call__(self, hidden_state):
-        # TODO: Add ACT2FN reference to change activation function
-        dense = nn.Dense(
+    def __call__(self, hidden_states):
+        hidden_states = nn.Dense(
             features=self.output_size,
             kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
             name="dense",
             dtype=self.dtype,
-        )(hidden_state)
-        return gelu(dense)
+        )(hidden_states)
+        hidden_states = ACT2FN[self.hidden_act](hidden_states)
+        return hidden_states
 
 
 class FlaxBertOutput(nn.Module):
@@ -249,27 +251,28 @@ class FlaxBertOutput(nn.Module):
 
     @nn.compact
     def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
-        hidden_state = nn.Dense(
+        hidden_states = nn.Dense(
             attention_output.shape[-1],
             kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
             name="dense",
             dtype=self.dtype,
         )(intermediate_output)
-        hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
-        hidden_state = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
-        return hidden_state
+        hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
+        hidden_states = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
+        return hidden_states
 
 
 class FlaxBertLayer(nn.Module):
     num_heads: int
     head_size: int
     intermediate_size: int
+    hidden_act: str = "gelu"
     dropout_rate: float = 0.0
     kernel_init_scale: float = 0.2
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     @nn.compact
-    def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
+    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
         attention = FlaxBertAttention(
             self.num_heads,
             self.head_size,
@@ -277,9 +280,13 @@ class FlaxBertLayer(nn.Module):
             dropout_rate=self.dropout_rate,
             name="attention",
             dtype=self.dtype,
-        )(hidden_state, attention_mask, deterministic=deterministic)
+        )(hidden_states, attention_mask, deterministic=deterministic)
         intermediate = FlaxBertIntermediate(
-            self.intermediate_size, kernel_init_scale=self.kernel_init_scale, name="intermediate", dtype=self.dtype
+            self.intermediate_size,
+            kernel_init_scale=self.kernel_init_scale,
+            hidden_act=self.hidden_act,
+            name="intermediate",
+            dtype=self.dtype,
         )(attention)
         output = FlaxBertOutput(
             kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
@@ -297,6 +304,7 @@ class FlaxBertLayerCollection(nn.Module):
     num_heads: int
     head_size: int
     intermediate_size: int
+    hidden_act: str = "gelu"
     dropout_rate: float = 0.0
     kernel_init_scale: float = 0.2
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
@@ -316,6 +324,7 @@ class FlaxBertLayerCollection(nn.Module):
                 self.intermediate_size,
                 kernel_init_scale=self.kernel_init_scale,
                 dropout_rate=self.dropout_rate,
+                hidden_act=self.hidden_act,
                 name=f"{i}",
                 dtype=self.dtype,
             )
@@ -328,22 +337,24 @@ class FlaxBertEncoder(nn.Module):
     num_heads: int
     head_size: int
     intermediate_size: int
+    hidden_act: str = "gelu"
     dropout_rate: float = 0.0
     kernel_init_scale: float = 0.2
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     @nn.compact
-    def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
+    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
         layer = FlaxBertLayerCollection(
             self.num_layers,
             self.num_heads,
             self.head_size,
             self.intermediate_size,
+            hidden_act=self.hidden_act,
             kernel_init_scale=self.kernel_init_scale,
             dropout_rate=self.dropout_rate,
             name="layer",
             dtype=self.dtype,
-        )(hidden_state, attention_mask, deterministic=deterministic)
+        )(hidden_states, attention_mask, deterministic=deterministic)
         return layer
 
 
@@ -352,10 +363,10 @@ class FlaxBertPooler(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     @nn.compact
-    def __call__(self, hidden_state):
-        cls_token = hidden_state[:, 0]
+    def __call__(self, hidden_states):
+        cls_token = hidden_states[:, 0]
         out = nn.Dense(
-            hidden_state.shape[-1],
+            hidden_states.shape[-1],
             kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
             name="dense",
             dtype=self.dtype,
@@ -363,62 +374,20 @@ class FlaxBertPooler(nn.Module):
         return nn.tanh(out)
 
 
-class FlaxBertModule(nn.Module):
-    vocab_size: int
-    hidden_size: int
-    type_vocab_size: int
-    max_length: int
-    num_encoder_layers: int
-    num_heads: int
-    head_size: int
-    intermediate_size: int
-    dropout_rate: float = 0.0
-    kernel_init_scale: float = 0.2
-    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-
-    @nn.compact
-    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
-
-        # Embedding
-        embeddings = FlaxBertEmbeddings(
-            self.vocab_size,
-            self.hidden_size,
-            self.type_vocab_size,
-            self.max_length,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            name="embeddings",
-            dtype=self.dtype,
-        )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
-
-        # N stacked encoding layers
-        encoder = FlaxBertEncoder(
-            self.num_encoder_layers,
-            self.num_heads,
-            self.head_size,
-            self.intermediate_size,
-            kernel_init_scale=self.kernel_init_scale,
-            dropout_rate=self.dropout_rate,
-            name="encoder",
-            dtype=self.dtype,
-        )(embeddings, attention_mask, deterministic=deterministic)
-
-        pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
-        return encoder, pooled
-
-
 class FlaxBertPredictionHeadTransform(nn.Module):
+    hidden_act: str = "gelu"
     dtype: jnp.dtype = jnp.float32
 
     @nn.compact
     def __call__(self, hidden_states):
         hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
-        hidden_states = nn.elu(hidden_states)  # TODO: ACT2FN[config.hidden_act]
-        return FlaxBertLayerNorm(name="LayerNorm", dtype=self.dtype)(hidden_states)
+        hidden_states = ACT2FN[self.hidden_act](hidden_states)
+        return FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states)
 
 
 class FlaxBertLMPredictionHead(nn.Module):
     vocab_size: int
+    hidden_act: str = "gelu"
     dtype: jnp.dtype = jnp.float32
 
     @nn.compact
@@ -428,64 +397,57 @@ class FlaxBertLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly
         # resized with `resize_token_embeddings`
 
-        hidden_states = FlaxBertPredictionHeadTransform(name="transform", dtype=self.dtype)(hidden_states)
+        hidden_states = FlaxBertPredictionHeadTransform(
+            name="transform", hidden_act=self.hidden_act, dtype=self.dtype
+        )(hidden_states)
         hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states)
         return hidden_states
 
 
 class FlaxBertOnlyMLMHead(nn.Module):
     vocab_size: int
-    hidden_size: int
-    intermediate_size: int
-    head_size: int
-    num_heads: int
-    num_encoder_layers: int
-    type_vocab_size: int
-    max_length: int
-    dropout_rate: float = 0.0
+    hidden_act: str = "gelu"
     dtype: jnp.dtype = jnp.float32
 
     @nn.compact
-    def __call__(
-        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
-    ):
-        # Model
-        encoder, pooled = FlaxBertModule(
-            vocab_size=self.vocab_size,
-            type_vocab_size=self.type_vocab_size,
-            hidden_size=self.hidden_size,
-            intermediate_size=self.intermediate_size,
-            head_size=self.hidden_size,
-            num_heads=self.num_heads,
-            num_encoder_layers=self.num_encoder_layers,
-            max_length=self.max_length,
-            dropout_rate=self.dropout_rate,
-            dtype=self.dtype,
-        )(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
-
-        # Compute the prediction scores
-        encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
-        logits = FlaxBertLMPredictionHead(vocab_size=self.vocab_size, name="predictions", dtype=self.dtype)(encoder)
-
-        return logits, pooled
+    def __call__(self, hidden_states):
+        hidden_states = FlaxBertLMPredictionHead(
+            vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="predictions", dtype=self.dtype
+        )(hidden_states)
+        return hidden_states
 
 
-@add_start_docstrings(
-    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
-    BERT_START_DOCSTRING,
-)
-class FlaxBertModel(FlaxPreTrainedModel):
+class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
     """
-    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
-    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
-    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
-    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
     """
 
     model_class = FlaxBertModule
     config_class = BertConfig
     base_model_prefix = "bert"
 
+    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
+        if token_type_ids is None:
+            token_type_ids = jnp.ones_like(input_ids)
+
+        if position_ids is None:
+            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
+
+        return input_ids, attention_mask, token_type_ids, position_ids
+
+    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
+            jnp.zeros(input_shape, dtype="i4"), None, None, None
+        )
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+
     @staticmethod
     def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
         jax_state = dict(pt_state)
@@ -501,6 +463,11 @@ class FlaxBertModel(FlaxPreTrainedModel):
             key = key.replace("weight", "kernel")
             jax_state[key] = tensor
 
+        if "decoder.weight" in key:
+            del jax_state[key]
+            key = key.replace("weight", "kernel")
+            jax_state[key] = tensor.T
+
         # SelfAttention needs also to replace "weight" by "kernel"
         if {"query", "key", "value"} & key_parts:
 
@@ -526,7 +493,7 @@ class FlaxBertModel(FlaxPreTrainedModel):
             jax_state[key] = tensor
 
         # There are some transposed parameters w.r.t their PyTorch counterpart
-        if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
+        if "intermediate.dense.kernel" in key or "output.dense.kernel" in key or "transform.dense.kernel" in key:
             jax_state[key] = tensor.T
 
         # Self Attention output projection needs to be transposed
@@ -539,6 +506,11 @@ class FlaxBertModel(FlaxPreTrainedModel):
         if "pooler.dense.kernel" in key:
             jax_state[key] = tensor.T
 
+        # Hack to correctly load some pytorch models
+        if "predictions.bias" in key:
+            del jax_state[key]
+            jax_state[".".join(key.split(".")[:2]) + ".decoder.bias"] = tensor
+
         # Handle LayerNorm conversion
         if "LayerNorm" in key:
             del jax_state[key]
@@ -555,7 +527,22 @@ class FlaxBertModel(FlaxPreTrainedModel):
 
         return jax_state
 
-    def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
+
+@add_start_docstrings(
+    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
+    BERT_START_DOCSTRING,
+)
+class FlaxBertModel(FlaxBertPreTrainedModel):
+    """
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+    """
+
+    def __init__(
+        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
+    ):
         module = FlaxBertModule(
             vocab_size=config.vocab_size,
             hidden_size=config.hidden_size,
@@ -566,10 +553,12 @@ class FlaxBertModel(FlaxPreTrainedModel):
             head_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             dropout_rate=config.hidden_dropout_prob,
+            hidden_act=config.hidden_act,
             dtype=dtype,
+            **kwargs,
         )
 
-        super().__init__(config, module, state, seed)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
@@ -601,34 +590,62 @@ class FlaxBertModel(FlaxPreTrainedModel):
             rngs=rngs,
         )
 
-    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
-        if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
-
-        if position_ids is None:
-            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
-
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            jnp.zeros(input_shape, dtype="i4"), None, None, None
-        )
-
-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+
+class FlaxBertModule(nn.Module):
+    vocab_size: int
+    hidden_size: int
+    type_vocab_size: int
+    max_length: int
+    num_encoder_layers: int
+    num_heads: int
+    head_size: int
+    intermediate_size: int
+    hidden_act: str = "gelu"
+    dropout_rate: float = 0.0
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    add_pooling_layer: bool = True
+
+    @nn.compact
+    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
+
+        # Embedding
+        embeddings = FlaxBertEmbeddings(
+            self.vocab_size,
+            self.hidden_size,
+            self.type_vocab_size,
+            self.max_length,
+            kernel_init_scale=self.kernel_init_scale,
+            dropout_rate=self.dropout_rate,
+            name="embeddings",
+            dtype=self.dtype,
+        )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
+
+        # N stacked encoding layers
+        encoder = FlaxBertEncoder(
+            self.num_encoder_layers,
+            self.num_heads,
+            self.head_size,
+            self.intermediate_size,
+            kernel_init_scale=self.kernel_init_scale,
+            dropout_rate=self.dropout_rate,
+            hidden_act=self.hidden_act,
+            name="encoder",
+            dtype=self.dtype,
+        )(embeddings, attention_mask, deterministic=deterministic)
+
+        if not self.add_pooling_layer:
+            return encoder
+
+        pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
+        return encoder, pooled
 
 
-class FlaxBertForMaskedLM(FlaxBertModel):
-    def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
-        super().__init__(config, state, seed, dtype)
-
-        self._module = FlaxBertOnlyMLMHead(
+class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
+    def __init__(
+        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
+    ):
+        module = FlaxBertForMaskedLMModule(
            vocab_size=config.vocab_size,
             type_vocab_size=config.type_vocab_size,
             hidden_size=config.hidden_size,
@@ -636,10 +653,13 @@ class FlaxBertForMaskedLM(FlaxBertModel):
             head_size=config.hidden_size,
             num_heads=config.num_attention_heads,
             num_encoder_layers=config.num_hidden_layers,
-            max_length=config.max_length,
+            max_length=config.max_position_embeddings,
+            hidden_act=config.hidden_act,
             **kwargs,
         )
 
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+
     def __call__(
         self,
         input_ids,
@@ -659,7 +679,7 @@ class FlaxBertForMaskedLM(FlaxBertModel):
         if dropout_rng is not None:
             rngs["dropout"] = dropout_rng
 
-        pooled, logits = self.module.apply(
+        return self.module.apply(
             {"params": params or self.params},
             jnp.array(input_ids, dtype="i4"),
             jnp.array(attention_mask, dtype="i4"),
@@ -669,4 +689,45 @@ class FlaxBertForMaskedLM(FlaxBertModel):
             rngs=rngs,
         )
 
-        return logits, pooled
+
+class FlaxBertForMaskedLMModule(nn.Module):
+    vocab_size: int
+    hidden_size: int
+    intermediate_size: int
+    head_size: int
+    num_heads: int
+    num_encoder_layers: int
+    type_vocab_size: int
+    max_length: int
+    hidden_act: str
+    dropout_rate: float = 0.0
+    dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(
+        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
+    ):
+        # Model
+        encoder = FlaxBertModule(
+            vocab_size=self.vocab_size,
+            type_vocab_size=self.type_vocab_size,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.intermediate_size,
+            head_size=self.hidden_size,
+            num_heads=self.num_heads,
+            num_encoder_layers=self.num_encoder_layers,
+            max_length=self.max_length,
+            dropout_rate=self.dropout_rate,
+            hidden_act=self.hidden_act,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            name="bert",
+        )(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
+
+        # Compute the prediction scores
+        encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
+        logits = FlaxBertOnlyMLMHead(
+            vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
+        )(encoder)
+
+        return (logits,)
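The net effect of the BERT refactor: the computation now lives in plain ``nn.Module`` classes (``FlaxBertModule``, ``FlaxBertForMaskedLMModule``) while the public classes inherit ``init``/``from_pretrained`` from ``FlaxBertPreTrainedModel``. A rough instantiation sketch (config and shapes illustrative; the optional mask/type/position inputs are assumed to default via ``_check_inputs``):

    import jax.numpy as jnp
    from transformers import BertConfig, FlaxBertForMaskedLM

    config = BertConfig()  # default hyper-parameters
    model = FlaxBertForMaskedLM(config, input_shape=(1, 128), seed=0)  # random init through init()
    input_ids = jnp.ones((1, 128), dtype="i4")
    logits = model(input_ids)[0]  # FlaxBertForMaskedLMModule returns (logits,)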
src/transformers/models/roberta/modeling_flax_roberta.py

@@ -19,10 +19,11 @@ import numpy as np
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
 from jax.random import PRNGKey
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
 from ...utils import logging
 from .configuration_roberta import RobertaConfig
 
@@ -33,6 +34,23 @@ _CONFIG_FOR_DOC = "RobertaConfig"
 _TOKENIZER_FOR_DOC = "RobertaTokenizer"
 
 
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: jnp.ndarray
+        padding_idx: int
+
+    Returns: jnp.ndarray
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = (input_ids != padding_idx).astype("i4")
+    incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
+    return incremental_indices.astype("i4") + padding_idx
+
+
 ROBERTA_START_DOCSTRING = r"""
 
     This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
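A quick worked example of the new helper, following its arithmetic line by line (values illustrative):

    import jax.numpy as jnp

    input_ids = jnp.array([[5, 6, 1, 1]])  # padding_idx = 1
    # mask                -> [[1, 1, 0, 0]]
    # cumsum(mask) * mask -> [[1, 2, 0, 0]]
    # + padding_idx       -> [[2, 3, 1, 1]]
    position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)
    # non-padding tokens count up from padding_idx + 1; padding positions stay at padding_idx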
@ -208,7 +226,7 @@ class FlaxRobertaAttention(nn.Module):
|
|||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
|
|
@ -222,28 +240,29 @@ class FlaxRobertaAttention(nn.Module):
|
|||
bias_init=jax.nn.initializers.zeros,
|
||||
name="self",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask)
|
||||
)(hidden_states, attention_mask)
|
||||
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
|
||||
return layer_norm
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
||||
class FlaxRobertaIntermediate(nn.Module):
|
||||
output_size: int
|
||||
hidden_act: str = "gelu"
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
# TODO: Add ACT2FN reference to change activation function
|
||||
dense = nn.Dense(
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = nn.Dense(
|
||||
features=self.output_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state)
|
||||
return gelu(dense)
|
||||
)(hidden_states)
|
||||
hidden_states = ACT2FN[self.hidden_act](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
||||
|
|
@ -254,27 +273,28 @@ class FlaxRobertaOutput(nn.Module):
|
|||
|
||||
@nn.compact
|
||||
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
||||
hidden_state = nn.Dense(
|
||||
hidden_states = nn.Dense(
|
||||
attention_output.shape[-1],
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(intermediate_output)
|
||||
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
|
||||
hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
|
||||
return hidden_state
|
||||
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
|
||||
hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxRobertaLayer(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
attention = FlaxRobertaAttention(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
|
|
@ -282,10 +302,11 @@ class FlaxRobertaLayer(nn.Module):
|
|||
dropout_rate=self.dropout_rate,
|
||||
name="attention",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
||||
)(hidden_states, attention_mask, deterministic=deterministic)
|
||||
intermediate = FlaxRobertaIntermediate(
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
hidden_act=self.hidden_act,
|
||||
name="intermediate",
|
||||
dtype=self.dtype,
|
||||
)(attention)
|
||||
|
|
@@ -306,6 +327,7 @@ class FlaxRobertaLayerCollection(nn.Module):
     num_heads: int
     head_size: int
     intermediate_size: int
+    hidden_act: str = "gelu"
     dropout_rate: float = 0.0
     kernel_init_scale: float = 0.2
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
@@ -325,6 +347,7 @@ class FlaxRobertaLayerCollection(nn.Module):
                 self.intermediate_size,
                 kernel_init_scale=self.kernel_init_scale,
                 dropout_rate=self.dropout_rate,
+                hidden_act=self.hidden_act,
                 name=f"{i}",
                 dtype=self.dtype,
             )
@@ -338,22 +361,24 @@ class FlaxRobertaEncoder(nn.Module):
     num_heads: int
     head_size: int
     intermediate_size: int
+    hidden_act: str = "gelu"
     dropout_rate: float = 0.0
     kernel_init_scale: float = 0.2
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
-    def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
+    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
         layer = FlaxRobertaLayerCollection(
             self.num_layers,
             self.num_heads,
             self.head_size,
             self.intermediate_size,
+            hidden_act=self.hidden_act,
             kernel_init_scale=self.kernel_init_scale,
             dropout_rate=self.dropout_rate,
             name="layer",
             dtype=self.dtype,
-        )(hidden_state, attention_mask, deterministic=deterministic)
+        )(hidden_states, attention_mask, deterministic=deterministic)
         return layer
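
The hunks above all do one thing: thread a single hidden_act string from the top-level module through encoder, layer collection, and layer down into FlaxRobertaIntermediate, where ACT2FN finally resolves it. An illustrative sketch of the effect (all other defaults untouched):

from transformers import RobertaConfig

# Setting hidden_act once at the config level switches the activation used in
# every FlaxRobertaIntermediate block, with no other code change required.
config = RobertaConfig(hidden_act="gelu_new")
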
@@ -363,10 +388,10 @@ class FlaxRobertaPooler(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     @nn.compact
-    def __call__(self, hidden_state):
-        cls_token = hidden_state[:, 0]
+    def __call__(self, hidden_states):
+        cls_token = hidden_states[:, 0]
         out = nn.Dense(
-            hidden_state.shape[-1],
+            hidden_states.shape[-1],
             kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
             name="dense",
             dtype=self.dtype,
|
|||
return nn.tanh(out)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
||||
class FlaxRobertaModule(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
num_encoder_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
|
||||
# Embedding
|
||||
embeddings = FlaxRobertaEmbeddings(
|
||||
self.vocab_size,
|
||||
self.hidden_size,
|
||||
self.type_vocab_size,
|
||||
self.max_length,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="embeddings",
|
||||
dtype=self.dtype,
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
||||
|
||||
# N stacked encoding layers
|
||||
encoder = FlaxRobertaEncoder(
|
||||
self.num_encoder_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="encoder",
|
||||
dtype=self.dtype,
|
||||
)(embeddings, attention_mask, deterministic=deterministic)
|
||||
|
||||
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
||||
return encoder, pooled
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
||||
Kaiser and Illia Polosukhin.
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
model_class = FlaxRobertaModule
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
|
|
@ -504,7 +477,49 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
|||
|
||||
return jax_state
|
||||
|
||||
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
|
||||
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
||||
)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
||||
|
||||
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
return input_ids, attention_mask, token_type_ids, position_ids
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
||||
"""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
||||
Kaiser and Illia Polosukhin.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RobertaConfig,
|
||||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs
|
||||
):
|
||||
module = FlaxRobertaModule(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
|
|
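
Taken together, the new init/_check_inputs plus the slimmer __init__ mean a Flax model is now built from a config alone (parameters come from module.init) or from saved weights via from_pretrained. A hedged usage sketch (checkpoint name and the dummy input are placeholders):

import numpy as np
from transformers import RobertaConfig, FlaxRobertaModel

model = FlaxRobertaModel(RobertaConfig())                  # fresh params via module.init(...)
model = FlaxRobertaModel.from_pretrained("roberta-base")   # params loaded from saved weights
outputs = model(np.ones((1, 1)))                           # mirrors the new slow tests further down
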
@@ -513,12 +528,14 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
             num_encoder_layers=config.num_hidden_layers,
             num_heads=config.num_attention_heads,
             head_size=config.hidden_size,
+            hidden_act=config.hidden_act,
             intermediate_size=config.intermediate_size,
             dropout_rate=config.hidden_dropout_prob,
             dtype=dtype,
+            **kwargs,
         )

-        super().__init__(config, module, state, seed)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
@@ -550,42 +567,53 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
             rngs=rngs,
         )

-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            jnp.zeros(input_shape, dtype="i4"), None, None, None
-        )
-
-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
-
-    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
-        if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
-
-        if position_ids is None:
-            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
-
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-
-def create_position_ids_from_input_ids(input_ids, padding_idx):
-    """
-    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
-    are ignored. This is modified from fairseq's `utils.make_positions`.
-
-    Args:
-        input_ids: jnp.ndarray
-        padding_idx: int
-
-    Returns: jnp.ndarray
-    """
-    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
-    mask = (input_ids != padding_idx).astype("i4")
-    incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
-    return incremental_indices.astype("i4") + padding_idx
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
+class FlaxRobertaModule(nn.Module):
+    vocab_size: int
+    hidden_size: int
+    type_vocab_size: int
+    max_length: int
+    num_encoder_layers: int
+    num_heads: int
+    head_size: int
+    intermediate_size: int
+    hidden_act: str = "gelu"
+    dropout_rate: float = 0.0
+    kernel_init_scale: float = 0.2
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    add_pooling_layer: bool = True
+
+    @nn.compact
+    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
+
+        # Embedding
+        embeddings = FlaxRobertaEmbeddings(
+            self.vocab_size,
+            self.hidden_size,
+            self.type_vocab_size,
+            self.max_length,
+            kernel_init_scale=self.kernel_init_scale,
+            dropout_rate=self.dropout_rate,
+            name="embeddings",
+            dtype=self.dtype,
+        )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
+
+        # N stacked encoding layers
+        encoder = FlaxRobertaEncoder(
+            self.num_encoder_layers,
+            self.num_heads,
+            self.head_size,
+            self.intermediate_size,
+            kernel_init_scale=self.kernel_init_scale,
+            dropout_rate=self.dropout_rate,
+            hidden_act=self.hidden_act,
+            name="encoder",
+            dtype=self.dtype,
+        )(embeddings, attention_mask, deterministic=deterministic)
+
+        if not self.add_pooling_layer:
+            return encoder
+
+        pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
+        return encoder, pooled
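
Since create_position_ids_from_input_ids only moves in this hunk, a quick worked example of its contract may help: positions count up from padding_idx + 1 and padding slots keep padding_idx (RoBERTa uses pad_token_id = 1).

import jax.numpy as jnp

ids = jnp.array([[5, 7, 9, 1, 1]])      # trailing 1s are padding
mask = (ids != 1).astype("i4")          # [[1, 1, 1, 0, 0]]
positions = jnp.cumsum(mask, axis=1) * mask + 1
print(positions)                        # [[2 3 4 1 1]]
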
@@ -2,6 +2,15 @@
 from ..file_utils import requires_flax


+class FlaxPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_flax(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_flax(self)
+
+
 FLAX_MODEL_MAPPING = None
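
This dummy class keeps `from transformers import FlaxPreTrainedModel` importable in environments without Flax, deferring the failure until the object is actually used. A sketch of the guard it calls (the exact wording and layout of the real requires_flax helper are assumptions):

def requires_flax(obj):
    # Assumed shape of the guard: resolve a printable name, then raise at use
    # time if Flax is missing. is_flax_available() comes from transformers.file_utils.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_flax_available():
        raise ImportError(f"{name} requires the Flax library, but it was not found in your environment.")
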
@@ -14,14 +14,16 @@

 import unittest

+import numpy as np
+
 from transformers import BertConfig, is_flax_available
-from transformers.testing_utils import require_flax
+from transformers.testing_utils import require_flax, slow

 from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask


 if is_flax_available():
-    from transformers.models.bert.modeling_flax_bert import FlaxBertModel
+    from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel


 class FlaxBertModelTester(unittest.TestCase):
@@ -105,7 +107,14 @@ class FlaxBertModelTester(unittest.TestCase):
 @require_flax
 class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):

-    all_model_classes = (FlaxBertModel,) if is_flax_available() else ()
+    all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()

     def setUp(self):
         self.model_tester = FlaxBertModelTester(self)
+
+    @slow
+    def test_model_from_pretrained(self):
+        for model_class_name in self.all_model_classes:
+            model = model_class_name.from_pretrained("bert-base-cased")
+            outputs = model(np.ones((1, 1)))
+            self.assertIsNotNone(outputs)
@@ -13,6 +13,7 @@
 # limitations under the License.

 import random
+import tempfile

 import numpy as np
@@ -26,7 +27,7 @@ if is_flax_available():

     import jax
     import jax.numpy as jnp
-    from flax.traverse_util import unflatten_dict
+    from transformers.modeling_flax_utils import convert_state_dict_from_pt

     os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
@@ -59,21 +60,13 @@ def random_attention_mask(shape, rng=None):
     return attn_mask


-def convert_pt_model_to_flax(pt_model, config, flax_model_cls):
-    state = pt_model.state_dict()
-    state = {k: v.numpy() for k, v in state.items()}
-    state = flax_model_cls.convert_from_pytorch(state, config)
-    state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
-    return flax_model_cls(config, state, dtype=jnp.float32)
-
-
 @require_flax
 class FlaxModelTesterMixin:
     model_tester = None
     all_model_classes = ()

     def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
-        diff = np.abs((a - b)).sum()
+        diff = np.abs((a - b)).max()
         self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")

     @require_torch
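
The .sum() to .max() switch turns the tolerance into a per-element bound rather than one that grows with tensor size, e.g.:

import numpy as np

a = np.zeros(10_000)
b = np.full(10_000, 1e-4)
np.abs(a - b).sum()   # 1.0   -> would fail a 5e-3 tolerance despite tiny per-element error
np.abs(a - b).max()   # 1e-4  -> passes
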
@@ -86,30 +79,54 @@ class FlaxModelTesterMixin:
                 pt_model_class = getattr(transformers, pt_model_class_name)
                 pt_model = pt_model_class(config).eval()

-                fx_model = convert_pt_model_to_flax(pt_model, config, model_class)
+                fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
+                fx_model = model_class(config, dtype=jnp.float32)
+                fx_model.params = fx_state

                 pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}

                 with torch.no_grad():
                     pt_outputs = pt_model(**pt_inputs).to_tuple()

                 fx_outputs = fx_model(**inputs_dict)
                 self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                 for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
+                    self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    pt_model.save_pretrained(tmpdirname)
+                    fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
+
+                fx_outputs_loaded = fx_model_loaded(**inputs_dict)
+                self.assertEqual(
+                    len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
+                )
+                for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
+                    self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
+
+    def test_from_pretrained_save_pretrained(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                model = model_class(config)
+
+                outputs = model(**inputs_dict)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    model.save_pretrained(tmpdirname)
+                    model_loaded = model_class.from_pretrained(tmpdirname)
+
+                outputs_loaded = model_loaded(**inputs_dict)
+                for output_loaded, output in zip(outputs_loaded, outputs):
+                    self.assert_almost_equals(output_loaded, output, 5e-3)

-    @require_torch
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:
             with self.subTest(model_class.__name__):

-                # TODO later: have some way to initialize easily a Flax model from config, for now I go through PT
-                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
-                pt_model_class = getattr(transformers, pt_model_class_name)
-                pt_model = pt_model_class(config).eval()
-
-                model = convert_pt_model_to_flax(pt_model, config, model_class)
+                model = model_class(config)

                 @jax.jit
                 def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
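
convert_state_dict_from_pt replaces the deleted convert_pt_model_to_flax helper above; judging from that helper, it most likely performs the same three steps but returns a parameter tree rather than a model instance (a sketch, not the verified implementation):

from flax.traverse_util import unflatten_dict

def convert_state_dict_from_pt(model_class, state, config):
    state = {k: v.numpy() for k, v in state.items()}           # torch.Tensor -> np.ndarray
    state = model_class.convert_from_pytorch(state, config)    # remap PT keys to the Flax layout
    return unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
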
@@ -125,3 +142,14 @@ class FlaxModelTesterMixin:
                 self.assertEqual(len(outputs), len(jitted_outputs))
                 for jitted_output, output in zip(jitted_outputs, outputs):
                     self.assertEqual(jitted_output.shape, output.shape)
+
+    def test_naming_convention(self):
+        for model_class in self.all_model_classes:
+            model_class_name = model_class.__name__
+            module_class_name = (
+                model_class_name[:-5] + "Module" if model_class_name[-5:] == "Model" else model_class_name + "Module"
+            )
+            bert_modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name])
+            module_cls = getattr(bert_modeling_flax_module, module_class_name)
+
+            self.assertIsNotNone(module_cls)
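
The naming rule that test_naming_convention enforces, distilled into a standalone check:

def module_name(model_class_name: str) -> str:
    # "...Model" -> "...Module"; any other class name just gets "Module" appended.
    if model_class_name.endswith("Model"):
        return model_class_name[:-5] + "Module"
    return model_class_name + "Module"

assert module_name("FlaxBertModel") == "FlaxBertModule"
assert module_name("FlaxBertForMaskedLM") == "FlaxBertForMaskedLMModule"
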
@@ -14,8 +14,10 @@

 import unittest

+import numpy as np
+
 from transformers import RobertaConfig, is_flax_available
-from transformers.testing_utils import require_flax
+from transformers.testing_utils import require_flax, slow

 from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
@@ -109,3 +111,10 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):

     def setUp(self):
         self.model_tester = FlaxRobertaModelTester(self)
+
+    @slow
+    def test_model_from_pretrained(self):
+        for model_class_name in self.all_model_classes:
+            model = model_class_name.from_pretrained("roberta-base")
+            outputs = model(np.ones((1, 1)))
+            self.assertIsNotNone(outputs)