From 640e6fe19062bb722f06dc3303ca2b6104de367d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 16 Dec 2020 13:03:32 +0100 Subject: [PATCH] [Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054) * save intermediate * save intermediate * save intermediate * correct flax bert model file * new module / model naming * make style * almost finish BERT * finish roberta * make fix-copies * delete keys file * last refactor * fixes in run_mlm_flax.py * remove pooled from run_mlm_flax.py` * fix gelu | gelu_new * remove Module from inits * splits * dirty print * preventing warmup_steps == 0 * smaller splits * make fix-copies * dirty print * dirty print * initial_evaluation argument * declaration order fix * proper model initialization/loading * proper initialization * run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug * removed tokenizers warning hack, fixed model re-initialization * reverted training_args.py changes * fix flax from pretrained * improve test in flax * apply sylvains tips * update init * make 0.3.0 compatible * revert tevens changes * revert tevens changes 2 * finalize revert * fix bug * add docs * add pretrained to init * Update src/transformers/modeling_flax_utils.py * fix copies * final improvements Co-authored-by: TevenLeScao --- docs/source/main_classes/model.rst | 14 +- examples/language-modeling/run_mlm_flax.py | 13 +- setup.py | 14 +- src/transformers/__init__.py | 1 + src/transformers/dependency_versions_table.py | 6 +- src/transformers/file_utils.py | 1 + src/transformers/modeling_flax_utils.py | 307 +++++++++++++--- .../models/auto/modeling_flax_auto.py | 23 +- .../models/bert/modeling_flax_bert.py | 329 +++++++++++------- .../models/roberta/modeling_flax_roberta.py | 248 +++++++------ src/transformers/utils/dummy_flax_objects.py | 9 + tests/test_modeling_flax_bert.py | 15 +- tests/test_modeling_flax_common.py | 68 ++-- tests/test_modeling_flax_roberta.py | 11 +- 14 files changed, 700 insertions(+), 359 deletions(-) diff --git a/docs/source/main_classes/model.rst b/docs/source/main_classes/model.rst index f3ded6d61..fef1426fa 100644 --- a/docs/source/main_classes/model.rst +++ b/docs/source/main_classes/model.rst @@ -13,9 +13,10 @@ Models ----------------------------------------------------------------------------------------------------------------------- -The base classes :class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` implement the -common methods for loading/saving a model either from a local file or directory, or from a pretrained model -configuration provided by the library (downloaded from HuggingFace's AWS S3 repository). +The base classes :class:`~transformers.PreTrainedModel`, :class:`~transformers.TFPreTrainedModel`, and +:class:`~transformers.FlaxPreTrainedModel` implement the common methods for loading/saving a model either from a local +file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace's AWS +S3 repository). :class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` also implement a few methods which are common among all the models to: @@ -57,6 +58,13 @@ TFModelUtilsMixin :members: +FlaxPreTrainedModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.FlaxPreTrainedModel + :members: + + Generation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/language-modeling/run_mlm_flax.py b/examples/language-modeling/run_mlm_flax.py index 5fe4aefce..83bf9bffe 100644 --- a/examples/language-modeling/run_mlm_flax.py +++ b/examples/language-modeling/run_mlm_flax.py @@ -385,7 +385,7 @@ def training_step(optimizer, batch, dropout_rng): # Hide away tokens which doesn't participate in the optimization token_mask = jnp.where(targets > 0, 1.0, 0.0) - pooled, logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True) + logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss, weight_sum = cross_entropy(logits, targets, token_mask) return loss / weight_sum @@ -407,7 +407,7 @@ def eval_step(params, batch): # Hide away tokens which doesn't participate in the optimization token_mask = jnp.where(targets > 0, 1.0, 0.0) - _, logits = model(**batch, params=params, train=False) + logits = model(**batch, params=params, train=False)[0] return compute_metrics(logits, targets, token_mask) @@ -572,8 +572,13 @@ if __name__ == "__main__": rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) - model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", dtype=jnp.float32, dropout_rate=0.1) - model.init(jax.random.PRNGKey(training_args.seed), (training_args.train_batch_size, model.config.max_length)) + model = FlaxBertForMaskedLM.from_pretrained( + "bert-base-cased", + dtype=jnp.float32, + input_shape=(training_args.train_batch_size, config.max_position_embeddings), + seed=training_args.seed, + dropout_rate=0.1, + ) # Setup optimizer optimizer = Adam( diff --git a/setup.py b/setup.py index 75497883b..16e9b5927 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ _deps = [ "fastapi", "filelock", "flake8>=3.8.3", - "flax==0.2.2", + "flax>=0.2.2", "fugashi>=1.0", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", @@ -175,7 +175,7 @@ class DepsTableUpdateCommand(Command): "deps = {", entries, "}", - "" + "", ] target = "src/transformers/dependency_versions_table.py" print(f"updating {target}") @@ -232,14 +232,14 @@ extras["dev"] = ( # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py install_requires = [ deps["dataclasses"] + ";python_version<'3.7'", # dataclasses for Python versions that don't have it - deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads + deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads deps["numpy"], deps["packaging"], # utilities from PyPA to e.g., compare versions - deps["regex"], # for OpenAI GPT - deps["requests"], # for downloading models over HTTPS - deps["sacremoses"], # for XLM + deps["regex"], # for OpenAI GPT + deps["requests"], # for downloading models over HTTPS + deps["sacremoses"], # for XLM deps["tokenizers"], - deps["tqdm"], # progress bars in model download and training scripts + deps["tqdm"], # progress bars in model download and training scripts ] setup( diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7311ee04c..1cf6bdd88 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -945,6 +945,7 @@ else: if is_flax_available(): + from .modeling_flax_utils import FlaxPreTrainedModel from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel from .models.bert import FlaxBertForMaskedLM, 
FlaxBertModel from .models.roberta import FlaxRobertaModel diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index fec88c778..39cf09256 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -10,7 +10,7 @@ deps = { "fastapi": "fastapi", "filelock": "filelock", "flake8": "flake8>=3.8.3", - "flax": "flax==0.2.2", + "flax": "flax>=0.2.2", "fugashi": "fugashi>=1.0", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", @@ -40,8 +40,8 @@ deps = { "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3", "sphinx": "sphinx==3.2.1", "starlette": "starlette", - "tensorflow-cpu": "tensorflow-cpu>=2.0", - "tensorflow": "tensorflow>=2.0", + "tensorflow-cpu": "tensorflow-cpu>=2.0,<2.4", + "tensorflow": "tensorflow>=2.0,<2.4", "timeout-decorator": "timeout-decorator", "tokenizers": "tokenizers==0.9.4", "torch": "torch>=1.0", diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 0e345bb28..420a6ff21 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -270,6 +270,7 @@ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) WEIGHTS_NAME = "pytorch_model.bin" TF2_WEIGHTS_NAME = "tf_model.h5" TF_WEIGHTS_NAME = "model.ckpt" +FLAX_WEIGHTS_NAME = "flax_model.msgpack" CONFIG_NAME = "config.json" MODEL_CARD_NAME = "modelcard.json" diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 69bf948b0..0823ddf1b 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -15,64 +15,65 @@ import os from abc import ABC, abstractmethod +from functools import partial from pickle import UnpicklingError -from typing import Dict +from typing import Dict, Set, Tuple, Union import flax.linen as nn import jax import jax.numpy as jnp -from flax.serialization import to_bytes -from flax.traverse_util import unflatten_dict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.serialization import from_bytes, to_bytes +from flax.traverse_util import flatten_dict, unflatten_dict from jax.random import PRNGKey from .configuration_utils import PretrainedConfig -from .file_utils import WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url +from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url from .utils import logging logger = logging.get_logger(__name__) -@jax.jit -def gelu(x): - r""" - Gaussian error linear unit activation function. - - Computes the element-wise function: - - .. math:: - \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( - \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right) - - We explicitly use the approximation rather than the exact formulation for speed. For more information, see - `Gaussian Error Linear Units (GELUs) `_, section 2. - """ - return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0))) - - ACT2FN = { "gelu": nn.gelu, "relu": nn.relu, "silu": nn.swish, "swish": nn.swish, - "gelu_new": gelu, + "gelu_new": partial(nn.gelu, approximate=True), } class FlaxPreTrainedModel(ABC): + r""" + Base class for all models. + + :class:`~transformers.FlaxPreTrainedModel` takes care of storing the configuration of the models and handles + methods for loading, downloading and saving models. 
+ + Class attributes (overridden by derived classes): + + - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of + :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. + - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in + derived classes of the same architecture adding modules on top of the base model. + """ config_class = None - pretrained_model_archive_map = {} base_model_prefix = "" - model_class = None def __init__( - self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0, dtype: jnp.dtype = jnp.float32 + self, + config: PretrainedConfig, + module: nn.Module, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, ): if config is None: raise ValueError("config cannot be None") - if params is None: - raise ValueError("state cannot be None") + if module is None: + raise ValueError("module cannot be None") # Those are private to be exposed as typed property on derived classes. self._config = config @@ -80,9 +81,18 @@ class FlaxPreTrainedModel(ABC): # Those are public as their type is generic to every derived classes. self.key = PRNGKey(seed) - self.params = params self.dtype = dtype + # randomely initialized parameters + random_params = self.init(self.key, input_shape) + + # save required_params as set + self._required_params = set(flatten_dict(unfreeze(random_params)).keys()) + self.params = random_params + + def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict: + raise NotImplementedError(f"init method has to be implemented for {self}") + @property def config(self) -> PretrainedConfig: return self._config @@ -91,24 +101,130 @@ class FlaxPreTrainedModel(ABC): def module(self) -> nn.Module: return self._module + @property + def params(self) -> Union[Dict, FrozenDict]: + return self._params + + @property + def required_params(self) -> Set: + return self._required_params + + @params.setter + def params(self, params: Union[Dict, FrozenDict]): + if isinstance(params, FrozenDict): + params = unfreeze(params) + param_keys = set(flatten_dict(params).keys()) + if len(self.required_params - param_keys) > 0: + raise ValueError( + "Some parameters are missing. Make sure that `params` include the following " + f"parameters {self.required_params - param_keys}" + ) + self._params = freeze(params) + @staticmethod @abstractmethod def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict: raise NotImplementedError() @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + dtype: jnp.dtype = jnp.float32, + *model_args, + **kwargs + ): + r""" - Instantiate a pretrained Flax model from a pre-trained model configuration. + Instantiate a pretrained flax model from a pre-trained model configuration. + + The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those + weights are discarded. 
+ + Parameters: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this + case, ``from_pt`` should be set to :obj:`True`. + model_args (sequence of positional arguments, `optional`): + All remaning positional arguments will be passed to the underlying model's ``__init__`` method. + config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`): + Can be either: + + - an instance of a class derived from :class:`~transformers.PretrainedConfig`, + - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`. + + Configuration for the model to use instead of an automatically loaded configuation. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the `model id` string of a pretrained + model). + - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded + by supplying the save directory. + - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a + configuration JSON file named `config.json` is found in the directory. + cache_dir (:obj:`Union[str, os.PathLike]`, `optional`): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + ``pretrained_model_name_or_path`` argument). + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (:obj:`Dict[str, str], `optional`): + A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to only look at local files (i.e., do not try to download the model). + revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + kwargs (remaining dictionary of keyword arguments, `optional`): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + :obj:`output_attentions=True`). 
Behaves differently depending on whether a ``config`` is provided or + automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the + underlying model's ``__init__`` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class + initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of + ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute + with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration + attribute will be passed to the underlying model's ``__init__`` function. + + Examples:: + + >>> from transformers import BertConfig, FlaxBertModel + >>> # Download model and configuration from huggingface.co and cache. + >>> model = FlaxBertModel.from_pretrained('bert-base-cased') + >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable). + >>> model = FlaxBertModel.from_pretrained('./test/saved_model/') + >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file('./pt_model/config.json') + >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config) """ config = kwargs.pop("config", None) - # state_dict = kwargs.pop("state_dict", None) cache_dir = kwargs.pop("cache_dir", None) - # from_tf = kwargs.pop("from_tf", False) + from_pt = kwargs.pop("from_pt", False) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) - # output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) @@ -135,10 +251,28 @@ class FlaxPreTrainedModel(ABC): # Load model if pretrained_model_name_or_path is not None: - if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + if os.path.isdir(pretrained_model_name_or_path): + if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) + else: + raise EnvironmentError( + "Error no file named {} found in directory {} or `from_pt` set to False".format( + [FLAX_WEIGHTS_NAME, WEIGHTS_NAME], + pretrained_model_name_or_path, + ) + ) + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path else: - archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision) + archive_file = hf_bucket_url( + pretrained_model_name_or_path, + filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME, + revision=revision, + ) # redirect to the cache, if necessary try: @@ -169,31 +303,96 @@ class FlaxPreTrainedModel(ABC): # Instantiate model. 
with open(resolved_archive_file, "rb") as state_f: try: - from flax.serialization import from_bytes - - state = from_bytes(cls.model_class, state_f) - except TypeError: - try: + if from_pt: import torch state = torch.load(state_f) - state = {k: v.numpy() for k, v in state.items()} - state = cls.convert_from_pytorch(state, config) - state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()}) - except UnpicklingError: - raise EnvironmentError( - f"Unable to convert model {archive_file} to Flax deserializable object. " - "Supported format are PyTorch archive or Flax msgpack" - ) - return cls(config, state, *model_args, **model_kwargs) + state = convert_state_dict_from_pt(cls, state, config) + else: + state = from_bytes(cls, state_f.read()) + except UnpicklingError: + raise EnvironmentError( + f"Unable to convert pytorch model {archive_file} to Flax deserializable object. " + ) - def save_pretrained(self, folder): - folder_abs = os.path.abspath(folder) + # init random models + model = cls(config, *model_args, **model_kwargs) - if not os.path.exists(folder_abs): - os.mkdir(folder_abs) + # if model is base model only use model_prefix key + if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state: + state = state[cls.base_model_prefix] - with open(os.path.join(folder_abs, f"{self._config.model_type}.flax", "wb")) as f: + # flatten dicts + state = flatten_dict(state) + random_state = flatten_dict(unfreeze(model.params)) + + missing_keys = model.required_params - set(state.keys()) + unexpected_keys = set(state.keys()) - model.required_params + + # add missing keys as random parameters + for missing_key in missing_keys: + state[missing_key] = random_state[missing_key] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " + f"initializing {model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" + f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " + f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized: {missing_keys}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" + f"If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {model.__class__.__name__} for predictions without further training." 
+ ) + + # set correct parameters + model.params = unflatten_dict(state) + return model + + def save_pretrained(self, save_directory: Union[str, os.PathLike]): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `:func:`~transformers.FlaxPreTrainedModel.from_pretrained`` class method + + Arguments: + save_directory (:obj:`str` or :obj:`os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + if os.path.isfile(save_directory): + logger.error("Provided path ({}) should be a directory, not a file".format(save_directory)) + return + os.makedirs(save_directory, exist_ok=True) + + # get abs dir + save_directory = os.path.abspath(save_directory) + # save config as well + self.config.save_pretrained(save_directory) + + # save model + with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f: model_bytes = to_bytes(self.params) f.write(model_bytes) + + +def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig): + """ + Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict + """ + state = {k: v.numpy() for k, v in state.items()} + state = model_class.convert_from_pytorch(state, config) + state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()}) + return state diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index dab92814a..0a65f332c 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -27,15 +27,6 @@ from .configuration_auto import AutoConfig, BertConfig, RobertaConfig logger = logging.get_logger(__name__) -ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict( - (key, value) - for pretrained_map in [ - FlaxBertModel.pretrained_model_archive_map, - FlaxRobertaModel.pretrained_model_archive_map, - ] - for key, value, in pretrained_map.items() -) - FLAX_MODEL_MAPPING = OrderedDict( [ (RobertaConfig, FlaxRobertaModel), @@ -114,10 +105,9 @@ class FlaxAutoModel(object): organization name, like ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing model weights saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this - case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` - argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model - using the provided conversion scripts and loading the PyTorch model afterwards. + - a path or url to a `pytorch index checkpoint file` (e.g. `./pt_model/pytorch_model.bin`). In this + case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` + argument. model_args: (`optional`) Sequence of positional arguments: All remaining positional arguments will be passed to the underlying model's ``__init__`` method @@ -133,13 +123,6 @@ class FlaxAutoModel(object): - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. - state_dict: (`optional`) dict: - an optional state dictionary for the model to use instead of a state dictionary loaded from saved - weights file. This option can be used if you want to create a model from a pretrained configuration but - load your own weights. 
In this case though, you should check if using - :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and - :func:`~transformers.FlaxPreTrainedModel.from_pretrained` is not a simpler option. - cache_dir: (`optional`) string: Path to a directory in which a downloaded pre-trained model configuration should be cached if the standard cache should not be used. diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index f0f39d975..9def58dad 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -20,10 +20,11 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_flax_utils import FlaxPreTrainedModel, gelu +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel from ...utils import logging from .configuration_bert import BertConfig @@ -205,7 +206,7 @@ class FlaxBertAttention(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask, deterministic: bool = True): + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) @@ -219,27 +220,28 @@ class FlaxBertAttention(nn.Module): bias_init=jax.nn.initializers.zeros, name="self", dtype=self.dtype, - )(hidden_state, attention_mask) + )(hidden_states, attention_mask) - layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state) + layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states) return layer_norm class FlaxBertIntermediate(nn.Module): output_size: int + hidden_act: str = "gelu" kernel_init_scale: float = 0.2 dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state): - # TODO: Add ACT2FN reference to change activation function - dense = nn.Dense( + def __call__(self, hidden_states): + hidden_states = nn.Dense( features=self.output_size, kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), name="dense", dtype=self.dtype, - )(hidden_state) - return gelu(dense) + )(hidden_states) + hidden_states = ACT2FN[self.hidden_act](hidden_states) + return hidden_states class FlaxBertOutput(nn.Module): @@ -249,27 +251,28 @@ class FlaxBertOutput(nn.Module): @nn.compact def __call__(self, intermediate_output, attention_output, deterministic: bool = True): - hidden_state = nn.Dense( + hidden_states = nn.Dense( attention_output.shape[-1], kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), name="dense", dtype=self.dtype, )(intermediate_output) - hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic) - hidden_state = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output) - return hidden_state + hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic) + hidden_states = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output) + return hidden_states class 
FlaxBertLayer(nn.Module): num_heads: int head_size: int intermediate_size: int + hidden_act: str = "gelu" dropout_rate: float = 0.0 kernel_init_scale: float = 0.2 dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask, deterministic: bool = True): + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): attention = FlaxBertAttention( self.num_heads, self.head_size, @@ -277,9 +280,13 @@ class FlaxBertLayer(nn.Module): dropout_rate=self.dropout_rate, name="attention", dtype=self.dtype, - )(hidden_state, attention_mask, deterministic=deterministic) + )(hidden_states, attention_mask, deterministic=deterministic) intermediate = FlaxBertIntermediate( - self.intermediate_size, kernel_init_scale=self.kernel_init_scale, name="intermediate", dtype=self.dtype + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + hidden_act=self.hidden_act, + name="intermediate", + dtype=self.dtype, )(attention) output = FlaxBertOutput( kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype @@ -297,6 +304,7 @@ class FlaxBertLayerCollection(nn.Module): num_heads: int head_size: int intermediate_size: int + hidden_act: str = "gelu" dropout_rate: float = 0.0 kernel_init_scale: float = 0.2 dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -316,6 +324,7 @@ class FlaxBertLayerCollection(nn.Module): self.intermediate_size, kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, + hidden_act=self.hidden_act, name=f"{i}", dtype=self.dtype, ) @@ -328,22 +337,24 @@ class FlaxBertEncoder(nn.Module): num_heads: int head_size: int intermediate_size: int + hidden_act: str = "gelu" dropout_rate: float = 0.0 kernel_init_scale: float = 0.2 dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask, deterministic: bool = True): + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): layer = FlaxBertLayerCollection( self.num_layers, self.num_heads, self.head_size, self.intermediate_size, + hidden_act=self.hidden_act, kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="layer", dtype=self.dtype, - )(hidden_state, attention_mask, deterministic=deterministic) + )(hidden_states, attention_mask, deterministic=deterministic) return layer @@ -352,10 +363,10 @@ class FlaxBertPooler(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state): - cls_token = hidden_state[:, 0] + def __call__(self, hidden_states): + cls_token = hidden_states[:, 0] out = nn.Dense( - hidden_state.shape[-1], + hidden_states.shape[-1], kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), name="dense", dtype=self.dtype, @@ -363,62 +374,20 @@ class FlaxBertPooler(nn.Module): return nn.tanh(out) -class FlaxBertModule(nn.Module): - vocab_size: int - hidden_size: int - type_vocab_size: int - max_length: int - num_encoder_layers: int - num_heads: int - head_size: int - intermediate_size: int - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - @nn.compact - def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): - - # Embedding - embeddings = FlaxBertEmbeddings( - self.vocab_size, - self.hidden_size, - self.type_vocab_size, - self.max_length, - 
kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="embeddings", - dtype=self.dtype, - )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) - - # N stacked encoding layers - encoder = FlaxBertEncoder( - self.num_encoder_layers, - self.num_heads, - self.head_size, - self.intermediate_size, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="encoder", - dtype=self.dtype, - )(embeddings, attention_mask, deterministic=deterministic) - - pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder) - return encoder, pooled - - class FlaxBertPredictionHeadTransform(nn.Module): + hidden_act: str = "gelu" dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, hidden_states): hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states) - hidden_states = nn.elu(hidden_states) # TODO: ACT2FN[config.hidden_act] - return FlaxBertLayerNorm(name="LayerNorm", dtype=self.dtype)(hidden_states) + hidden_states = ACT2FN[self.hidden_act](hidden_states) + return FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states) class FlaxBertLMPredictionHead(nn.Module): vocab_size: int + hidden_act: str = "gelu" dtype: jnp.dtype = jnp.float32 @nn.compact @@ -428,64 +397,57 @@ class FlaxBertLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly # resized with `resize_token_embeddings` - hidden_states = FlaxBertPredictionHeadTransform(name="transform", dtype=self.dtype)(hidden_states) + hidden_states = FlaxBertPredictionHeadTransform( + name="transform", hidden_act=self.hidden_act, dtype=self.dtype + )(hidden_states) hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states) return hidden_states class FlaxBertOnlyMLMHead(nn.Module): vocab_size: int - hidden_size: int - intermediate_size: int - head_size: int - num_heads: int - num_encoder_layers: int - type_vocab_size: int - max_length: int - dropout_rate: float = 0.0 + hidden_act: str = "gelu" dtype: jnp.dtype = jnp.float32 @nn.compact - def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True - ): - # Model - encoder, pooled = FlaxBertModule( - vocab_size=self.vocab_size, - type_vocab_size=self.type_vocab_size, - hidden_size=self.hidden_size, - intermediate_size=self.intermediate_size, - head_size=self.hidden_size, - num_heads=self.num_heads, - num_encoder_layers=self.num_encoder_layers, - max_length=self.max_length, - dropout_rate=self.dropout_rate, - dtype=self.dtype, - )(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) - - # Compute the prediction scores - encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic) - logits = FlaxBertLMPredictionHead(vocab_size=self.vocab_size, name="predictions", dtype=self.dtype)(encoder) - - return logits, pooled + def __call__(self, hidden_states): + hidden_states = FlaxBertLMPredictionHead( + vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="predictions", dtype=self.dtype + )(hidden_states) + return hidden_states -@add_start_docstrings( - "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", - BERT_START_DOCSTRING, -) -class FlaxBertModel(FlaxPreTrainedModel): +class FlaxBertPreTrainedModel(FlaxPreTrainedModel): """ - The model can behave as an encoder (with only self-attention) 
as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in `Attention is - all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, - Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. """ - model_class = FlaxBertModule config_class = BertConfig base_model_prefix = "bert" + def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids): + if token_type_ids is None: + token_type_ids = jnp.ones_like(input_ids) + + if position_ids is None: + position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1]) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + return input_ids, attention_mask, token_type_ids, position_ids + + def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + jnp.zeros(input_shape, dtype="i4"), None, None, None + ) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"] + @staticmethod def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict: jax_state = dict(pt_state) @@ -501,6 +463,11 @@ class FlaxBertModel(FlaxPreTrainedModel): key = key.replace("weight", "kernel") jax_state[key] = tensor + if "decoder.weight" in key: + del jax_state[key] + key = key.replace("weight", "kernel") + jax_state[key] = tensor.T + # SelfAttention needs also to replace "weight" by "kernel" if {"query", "key", "value"} & key_parts: @@ -526,7 +493,7 @@ class FlaxBertModel(FlaxPreTrainedModel): jax_state[key] = tensor # There are some transposed parameters w.r.t their PyTorch counterpart - if "intermediate.dense.kernel" in key or "output.dense.kernel" in key: + if "intermediate.dense.kernel" in key or "output.dense.kernel" in key or "transform.dense.kernel" in key: jax_state[key] = tensor.T # Self Attention output projection needs to be transposed @@ -539,6 +506,11 @@ class FlaxBertModel(FlaxPreTrainedModel): if "pooler.dense.kernel" in key: jax_state[key] = tensor.T + # Hack to correctly load some pytorch models + if "predictions.bias" in key: + del jax_state[key] + jax_state[".".join(key.split(".")[:2]) + ".decoder.bias"] = tensor + # Handle LayerNorm conversion if "LayerNorm" in key: del jax_state[key] @@ -555,7 +527,22 @@ class FlaxBertModel(FlaxPreTrainedModel): return jax_state - def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32): + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class FlaxBertModel(FlaxBertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 
+ """ + + def __init__( + self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + ): module = FlaxBertModule( vocab_size=config.vocab_size, hidden_size=config.hidden_size, @@ -566,10 +553,12 @@ class FlaxBertModel(FlaxPreTrainedModel): head_size=config.hidden_size, intermediate_size=config.intermediate_size, dropout_rate=config.hidden_dropout_prob, + hidden_act=config.hidden_act, dtype=dtype, + **kwargs, ) - super().__init__(config, module, state, seed) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( @@ -601,34 +590,62 @@ class FlaxBertModel(FlaxPreTrainedModel): rngs=rngs, ) - def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids): - if token_type_ids is None: - token_type_ids = jnp.ones_like(input_ids) - if position_ids is None: - position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1]) +class FlaxBertModule(nn.Module): + vocab_size: int + hidden_size: int + type_vocab_size: int + max_length: int + num_encoder_layers: int + num_heads: int + head_size: int + intermediate_size: int + hidden_act: str = "gelu" + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) + @nn.compact + def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): - return input_ids, attention_mask, token_type_ids, position_ids + # Embedding + embeddings = FlaxBertEmbeddings( + self.vocab_size, + self.hidden_size, + self.type_vocab_size, + self.max_length, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="embeddings", + dtype=self.dtype, + )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) - def init(self, rng: jax.random.PRNGKey, input_shape: Tuple): - input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( - jnp.zeros(input_shape, dtype="i4"), None, None, None - ) + # N stacked encoding layers + encoder = FlaxBertEncoder( + self.num_encoder_layers, + self.num_heads, + self.head_size, + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + hidden_act=self.hidden_act, + name="encoder", + dtype=self.dtype, + )(embeddings, attention_mask, deterministic=deterministic) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} + if not self.add_pooling_layer: + return encoder - self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"] + pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder) + return encoder, pooled -class FlaxBertForMaskedLM(FlaxBertModel): - def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs): - super().__init__(config, state, seed, dtype) - - self._module = FlaxBertOnlyMLMHead( +class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): + def __init__( + self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + ): + module = FlaxBertForMaskedLMModule( vocab_size=config.vocab_size, type_vocab_size=config.type_vocab_size, hidden_size=config.hidden_size, @@ -636,10 
+653,13 @@ class FlaxBertForMaskedLM(FlaxBertModel): head_size=config.hidden_size, num_heads=config.num_attention_heads, num_encoder_layers=config.num_hidden_layers, - max_length=config.max_length, + max_length=config.max_position_embeddings, + hidden_act=config.hidden_act, **kwargs, ) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + def __call__( self, input_ids, @@ -659,7 +679,7 @@ class FlaxBertForMaskedLM(FlaxBertModel): if dropout_rng is not None: rngs["dropout"] = dropout_rng - pooled, logits = self.module.apply( + return self.module.apply( {"params": params or self.params}, jnp.array(input_ids, dtype="i4"), jnp.array(attention_mask, dtype="i4"), @@ -669,4 +689,45 @@ class FlaxBertForMaskedLM(FlaxBertModel): rngs=rngs, ) - return logits, pooled + +class FlaxBertForMaskedLMModule(nn.Module): + vocab_size: int + hidden_size: int + intermediate_size: int + head_size: int + num_heads: int + num_encoder_layers: int + type_vocab_size: int + max_length: int + hidden_act: str + dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + encoder = FlaxBertModule( + vocab_size=self.vocab_size, + type_vocab_size=self.type_vocab_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + head_size=self.hidden_size, + num_heads=self.num_heads, + num_encoder_layers=self.num_encoder_layers, + max_length=self.max_length, + dropout_rate=self.dropout_rate, + hidden_act=self.hidden_act, + dtype=self.dtype, + add_pooling_layer=False, + name="bert", + )(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + + # Compute the prediction scores + encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic) + logits = FlaxBertOnlyMLMHead( + vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype + )(encoder) + + return (logits,) diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index bafbdfc4d..64fc2bdc4 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -19,10 +19,11 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_flax_utils import FlaxPreTrainedModel, gelu +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel from ...utils import logging from .configuration_roberta import RobertaConfig @@ -33,6 +34,23 @@ _CONFIG_FOR_DOC = "RobertaConfig" _TOKENIZER_FOR_DOC = "RobertaTokenizer" +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: jnp.ndarray + padding_idx: int + + Returns: jnp.ndarray + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
+ mask = (input_ids != padding_idx).astype("i4") + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + return incremental_indices.astype("i4") + padding_idx + + ROBERTA_START_DOCSTRING = r""" This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the @@ -208,7 +226,7 @@ class FlaxRobertaAttention(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask, deterministic: bool = True): + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) @@ -222,28 +240,29 @@ class FlaxRobertaAttention(nn.Module): bias_init=jax.nn.initializers.zeros, name="self", dtype=self.dtype, - )(hidden_state, attention_mask) + )(hidden_states, attention_mask) - layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state) + layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states) return layer_norm # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta class FlaxRobertaIntermediate(nn.Module): output_size: int + hidden_act: str = "gelu" kernel_init_scale: float = 0.2 dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state): - # TODO: Add ACT2FN reference to change activation function - dense = nn.Dense( + def __call__(self, hidden_states): + hidden_states = nn.Dense( features=self.output_size, kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), name="dense", dtype=self.dtype, - )(hidden_state) - return gelu(dense) + )(hidden_states) + hidden_states = ACT2FN[self.hidden_act](hidden_states) + return hidden_states # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta @@ -254,27 +273,28 @@ class FlaxRobertaOutput(nn.Module): @nn.compact def __call__(self, intermediate_output, attention_output, deterministic: bool = True): - hidden_state = nn.Dense( + hidden_states = nn.Dense( attention_output.shape[-1], kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), name="dense", dtype=self.dtype, )(intermediate_output) - hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic) - hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output) - return hidden_state + hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic) + hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output) + return hidden_states class FlaxRobertaLayer(nn.Module): num_heads: int head_size: int intermediate_size: int + hidden_act: str = "gelu" dropout_rate: float = 0.0 kernel_init_scale: float = 0.2 dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask, deterministic: bool = True): + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): attention = FlaxRobertaAttention( self.num_heads, self.head_size, @@ -282,10 +302,11 @@ class FlaxRobertaLayer(nn.Module): dropout_rate=self.dropout_rate, name="attention", dtype=self.dtype, - 
)(hidden_state, attention_mask, deterministic=deterministic) + )(hidden_states, attention_mask, deterministic=deterministic) intermediate = FlaxRobertaIntermediate( self.intermediate_size, kernel_init_scale=self.kernel_init_scale, + hidden_act=self.hidden_act, name="intermediate", dtype=self.dtype, )(attention) @@ -306,6 +327,7 @@ class FlaxRobertaLayerCollection(nn.Module): num_heads: int head_size: int intermediate_size: int + hidden_act: str = "gelu" dropout_rate: float = 0.0 kernel_init_scale: float = 0.2 dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -325,6 +347,7 @@ class FlaxRobertaLayerCollection(nn.Module): self.intermediate_size, kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, + hidden_act=self.hidden_act, name=f"{i}", dtype=self.dtype, ) @@ -338,22 +361,24 @@ class FlaxRobertaEncoder(nn.Module): num_heads: int head_size: int intermediate_size: int + hidden_act: str = "gelu" dropout_rate: float = 0.0 kernel_init_scale: float = 0.2 dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask, deterministic: bool = True): + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): layer = FlaxRobertaLayerCollection( self.num_layers, self.num_heads, self.head_size, self.intermediate_size, + hidden_act=self.hidden_act, kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="layer", dtype=self.dtype, - )(hidden_state, attention_mask, deterministic=deterministic) + )(hidden_states, attention_mask, deterministic=deterministic) return layer @@ -363,10 +388,10 @@ class FlaxRobertaPooler(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state): - cls_token = hidden_state[:, 0] + def __call__(self, hidden_states): + cls_token = hidden_states[:, 0] out = nn.Dense( - hidden_state.shape[-1], + hidden_states.shape[-1], kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), name="dense", dtype=self.dtype, @@ -374,64 +399,12 @@ class FlaxRobertaPooler(nn.Module): return nn.tanh(out) -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta -class FlaxRobertaModule(nn.Module): - vocab_size: int - hidden_size: int - type_vocab_size: int - max_length: int - num_encoder_layers: int - num_heads: int - head_size: int - intermediate_size: int - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - @nn.compact - def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): - - # Embedding - embeddings = FlaxRobertaEmbeddings( - self.vocab_size, - self.hidden_size, - self.type_vocab_size, - self.max_length, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="embeddings", - dtype=self.dtype, - )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) - - # N stacked encoding layers - encoder = FlaxRobertaEncoder( - self.num_encoder_layers, - self.num_heads, - self.head_size, - self.intermediate_size, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="encoder", - dtype=self.dtype, - )(embeddings, attention_mask, deterministic=deterministic) - - pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder) - return encoder, pooled - - -@add_start_docstrings( - "The bare RoBERTa 
Model transformer outputting raw hidden-states without any specific head on top.", - ROBERTA_START_DOCSTRING, -) -class FlaxRobertaModel(FlaxPreTrainedModel): +class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): """ - The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in `Attention is - all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz - Kaiser and Illia Polosukhin. + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. """ - model_class = FlaxRobertaModule config_class = RobertaConfig base_model_prefix = "roberta" @@ -504,7 +477,49 @@ class FlaxRobertaModel(FlaxPreTrainedModel): return jax_state - def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32): + def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + jnp.zeros(input_shape, dtype="i4"), None, None, None + ) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"] + + def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids): + if token_type_ids is None: + token_type_ids = jnp.ones_like(input_ids) + + if position_ids is None: + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + return input_ids, attention_mask, token_type_ids, position_ids + + +@add_start_docstrings( + "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaModel(FlaxRobertaPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. 
+ """ + + def __init__( + self, + config: RobertaConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + **kwargs + ): module = FlaxRobertaModule( vocab_size=config.vocab_size, hidden_size=config.hidden_size, @@ -513,12 +528,14 @@ class FlaxRobertaModel(FlaxPreTrainedModel): num_encoder_layers=config.num_hidden_layers, num_heads=config.num_attention_heads, head_size=config.hidden_size, + hidden_act=config.hidden_act, intermediate_size=config.intermediate_size, dropout_rate=config.hidden_dropout_prob, dtype=dtype, + **kwargs, ) - super().__init__(config, module, state, seed) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( @@ -550,42 +567,53 @@ class FlaxRobertaModel(FlaxPreTrainedModel): rngs=rngs, ) - def init(self, rng: jax.random.PRNGKey, input_shape: Tuple): - input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( - jnp.zeros(input_shape, dtype="i4"), None, None, None - ) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta +class FlaxRobertaModule(nn.Module): + vocab_size: int + hidden_size: int + type_vocab_size: int + max_length: int + num_encoder_layers: int + num_heads: int + head_size: int + intermediate_size: int + hidden_act: str = "gelu" + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True - self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"] + @nn.compact + def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): - def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids): + # Embedding + embeddings = FlaxRobertaEmbeddings( + self.vocab_size, + self.hidden_size, + self.type_vocab_size, + self.max_length, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="embeddings", + dtype=self.dtype, + )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) - if token_type_ids is None: - token_type_ids = jnp.ones_like(input_ids) + # N stacked encoding layers + encoder = FlaxRobertaEncoder( + self.num_encoder_layers, + self.num_heads, + self.head_size, + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + hidden_act=self.hidden_act, + name="encoder", + dtype=self.dtype, + )(embeddings, attention_mask, deterministic=deterministic) - if position_ids is None: - position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + if not self.add_pooling_layer: + return encoder - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - return input_ids, attention_mask, token_type_ids, position_ids - - -def create_position_ids_from_input_ids(input_ids, padding_idx): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: jnp.ndarray - padding_idx: int - - Returns: jnp.ndarray - """ - # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
- mask = (input_ids != padding_idx).astype("i4") - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - return incremental_indices.astype("i4") + padding_idx + pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder) + return encoder, pooled diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 3c9b204b1..00773af27 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -2,6 +2,15 @@ from ..file_utils import requires_flax +class FlaxPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) + + FLAX_MODEL_MAPPING = None diff --git a/tests/test_modeling_flax_bert.py b/tests/test_modeling_flax_bert.py index 2276ef3b5..e201b8db8 100644 --- a/tests/test_modeling_flax_bert.py +++ b/tests/test_modeling_flax_bert.py @@ -14,14 +14,16 @@ import unittest +import numpy as np + from transformers import BertConfig, is_flax_available -from transformers.testing_utils import require_flax +from transformers.testing_utils import require_flax, slow from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask if is_flax_available(): - from transformers.models.bert.modeling_flax_bert import FlaxBertModel + from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel class FlaxBertModelTester(unittest.TestCase): @@ -105,7 +107,14 @@ class FlaxBertModelTester(unittest.TestCase): @require_flax class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): - all_model_classes = (FlaxBertModel,) if is_flax_available() else () + all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else () def setUp(self): self.model_tester = FlaxBertModelTester(self) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("bert-base-cased") + outputs = model(np.ones((1, 1))) + self.assertIsNotNone(outputs) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index c79407527..5b5bf54bd 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -13,6 +13,7 @@ # limitations under the License. 
import random +import tempfile import numpy as np @@ -26,7 +27,7 @@ if is_flax_available(): import jax import jax.numpy as jnp - from flax.traverse_util import unflatten_dict + from transformers.modeling_flax_utils import convert_state_dict_from_pt os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 @@ -59,21 +60,13 @@ def random_attention_mask(shape, rng=None): return attn_mask -def convert_pt_model_to_flax(pt_model, config, flax_model_cls): - state = pt_model.state_dict() - state = {k: v.numpy() for k, v in state.items()} - state = flax_model_cls.convert_from_pytorch(state, config) - state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()}) - return flax_model_cls(config, state, dtype=jnp.float32) - - @require_flax class FlaxModelTesterMixin: model_tester = None all_model_classes = () def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float): - diff = np.abs((a - b)).sum() + diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") @require_torch @@ -86,30 +79,54 @@ class FlaxModelTesterMixin: pt_model_class = getattr(transformers, pt_model_class_name) pt_model = pt_model_class(config).eval() - fx_model = convert_pt_model_to_flax(pt_model, config, model_class) + fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config) + fx_model = model_class(config, dtype=jnp.float32) + fx_model.params = fx_state pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()} with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() + fx_outputs = fx_model(**inputs_dict) self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3) + self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**inputs_dict) + self.assertEqual( + len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" + ) + for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3) + + def test_from_pretrained_save_pretrained(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + model = model_class(config) + + outputs = model(**inputs_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_loaded = model_class.from_pretrained(tmpdirname) + + outputs_loaded = model_loaded(**inputs_dict) + for output_loaded, output in zip(outputs_loaded, outputs): + self.assert_almost_equals(output_loaded, output, 5e-3) - @require_torch def test_jit_compilation(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: with self.subTest(model_class.__name__): - - # TODO later: have some way to initialize easily a Flax model from config, for now I go through PT - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - pt_model = pt_model_class(config).eval() - - model = 
convert_pt_model_to_flax(pt_model, config, model_class) + model = model_class(config) @jax.jit def model_jitted(input_ids, attention_mask=None, token_type_ids=None): @@ -125,3 +142,14 @@ class FlaxModelTesterMixin: self.assertEqual(len(outputs), len(jitted_outputs)) for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) + + def test_naming_convention(self): + for model_class in self.all_model_classes: + model_class_name = model_class.__name__ + module_class_name = ( + model_class_name[:-5] + "Module" if model_class_name[-5:] == "Model" else model_class_name + "Module" + ) + bert_modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name]) + module_cls = getattr(bert_modeling_flax_module, module_class_name) + + self.assertIsNotNone(module_cls) diff --git a/tests/test_modeling_flax_roberta.py b/tests/test_modeling_flax_roberta.py index fa20d9fa3..318d934ce 100644 --- a/tests/test_modeling_flax_roberta.py +++ b/tests/test_modeling_flax_roberta.py @@ -14,8 +14,10 @@ import unittest +import numpy as np + from transformers import RobertaConfig, is_flax_available -from transformers.testing_utils import require_flax +from transformers.testing_utils import require_flax, slow from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask @@ -109,3 +111,10 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = FlaxRobertaModelTester(self) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("roberta-base") + outputs = model(np.ones((1, 1))) + self.assertIsNotNone(outputs)
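
A minimal sketch of the loading/saving flow that the new Flax tests above exercise; the checkpoint name, input shape, and temporary-directory round trip mirror those tests and are illustrative only, not part of the patch:

    import tempfile

    import numpy as np

    from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM

    # Load pretrained weights into the Flax model (as in test_model_from_pretrained).
    model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased")
    outputs = model(np.ones((1, 1)))  # dummy input_ids, same shape the slow test uses

    # Save and reload the Flax weights (as in test_from_pretrained_save_pretrained).
    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)
        reloaded = FlaxBertForMaskedLM.from_pretrained(tmpdirname)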