mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Flax Speech-Encoder-Decoder Model (#15613)
* rebase * Delete shift tokens func * downsample decoder input seq len for init * correct attention mask * add tests * pt flax cross test * make fixup * init file for import * change pt-flax cross test threshold * pt-flax test logits only * move tests * make repo-consistency * consistent indentation Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
935a76d90d
commit
e3342edc4e
10 changed files with 1509 additions and 2 deletions
|
|
@ -230,7 +230,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||
| SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
|
||||
| Speech2Text | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| Speech2Text2 | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| Splinter | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
|
|
|||
|
|
@ -33,3 +33,9 @@ An example of how to use a [`SpeechEncoderDecoderModel`] for inference can be se
|
|||
[[autodoc]] SpeechEncoderDecoderModel
|
||||
- forward
|
||||
- from_encoder_decoder_pretrained
|
||||
|
||||
## FlaxSpeechEncoderDecoderModel
|
||||
|
||||
[[autodoc]] FlaxSpeechEncoderDecoderModel
|
||||
- __call__
|
||||
- from_encoder_decoder_pretrained
|
||||
|
|
@ -2295,6 +2295,7 @@ if is_flax_available():
|
|||
"FlaxRoFormerPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
|
||||
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
|
||||
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
|
||||
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
|
||||
|
|
@ -4183,6 +4184,7 @@ if TYPE_CHECKING:
|
|||
FlaxRoFormerModel,
|
||||
FlaxRoFormerPreTrainedModel,
|
||||
)
|
||||
from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
|
||||
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
|
||||
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
|
||||
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
|
||||
|
|
|
|||
|
|
@ -188,6 +188,12 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
|
|||
]
|
||||
)
|
||||
|
||||
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
|
||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
|
|
@ -215,6 +221,9 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
|
|||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
|
||||
)
|
||||
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
||||
)
|
||||
|
||||
|
||||
class FlaxAutoModel(_BaseAutoModelClass):
|
||||
|
|
@ -309,3 +318,12 @@ class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
|
|||
|
||||
|
||||
FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
|
||||
|
||||
|
||||
class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
|
||||
FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _LazyModule, is_torch_available
|
||||
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
|
@ -28,12 +28,18 @@ _import_structure = {
|
|||
if is_torch_available():
|
||||
_import_structure["modeling_speech_encoder_decoder"] = ["SpeechEncoderDecoderModel"]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_speech_encoder_decoder"] = ["FlaxSpeechEncoderDecoderModel"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_speech_encoder_decoder import SpeechEncoderDecoderModel
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,891 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Classes to support Flax Speech-Encoder-Decoder architectures"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
||||
from ...modeling_flax_utils import FlaxPreTrainedModel
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
|
||||
from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
|
||||
|
||||
SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
|
||||
This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
|
||||
autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
|
||||
loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via
|
||||
[`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
|
||||
and should be fine-tuned on a downstream generative task, like summarization.
|
||||
|
||||
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
|
||||
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
|
||||
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
|
||||
Zhou, Wei Li, Peter J. Liu.
|
||||
|
||||
Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
|
||||
Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
|
||||
translation yields a significant performance improvement.
|
||||
|
||||
After such an Speech-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
|
||||
models (see the examples for more information).
|
||||
|
||||
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a Flax Linen
|
||||
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
||||
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
||||
`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given `dtype`.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
||||
[`~FlaxPreTrainedModel.to_bf16`].
|
||||
"""
|
||||
|
||||
SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
|
||||
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
|
||||
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
|
||||
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
|
||||
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
|
||||
*torch.FloatTensor*.
|
||||
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
|
||||
right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
|
||||
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||
be used by default.
|
||||
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
||||
range `[0, config.decoder.max_position_embeddings - 1]`.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
If set to `True`, the model will return a [`~file_utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
|
||||
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
|
||||
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
|
||||
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
|
||||
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
|
||||
*torch.FloatTensor*.
|
||||
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
If set to `True`, the model will return a [`~file_utils.FlaxBaseModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is
|
||||
provided, the model will create this tensor by shifting the `input_ids` to the right for denoising
|
||||
pre-training.
|
||||
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
|
||||
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
||||
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
||||
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||
be used by default.
|
||||
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
||||
range `[0, config.decoder.max_position_embeddings - 1]`.
|
||||
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
||||
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
||||
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
If set to `True`, the model will return a [`~file_utils.FlaxCausalLMOutputWithCrossAttentions`] instead of
|
||||
a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class FlaxSpeechEncoderDecoderModule(nn.Module):
|
||||
config: SpeechEncoderDecoderConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
encoder_config = self.config.encoder
|
||||
decoder_config = self.config.decoder
|
||||
|
||||
# Copied from `modeling_hybrid_clip.py` with modifications.
|
||||
from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING
|
||||
|
||||
encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
|
||||
decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class
|
||||
|
||||
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
|
||||
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
|
||||
|
||||
# encoder outputs might need to be projected to different dimension for decoder
|
||||
if (
|
||||
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
self.enc_to_dec_proj = nn.Dense(
|
||||
self.decoder.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
self.enc_to_dec_proj = None
|
||||
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
|
||||
def _conv_out_length(input_length, kernel_size, stride):
|
||||
# 1D convolutional layer output length formula taken
|
||||
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
||||
return (input_length - kernel_size) // stride + 1
|
||||
|
||||
for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
|
||||
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
||||
|
||||
return input_lengths
|
||||
|
||||
def _get_encoder_module(self):
|
||||
return self.encoder
|
||||
|
||||
def _get_projection_module(self):
|
||||
return self.enc_to_dec_proj
|
||||
|
||||
def _get_decoder_module(self):
|
||||
return self.decoder
|
||||
|
||||
def freeze_feature_encoder(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature encoder of the speech encoder in
|
||||
order that its parameters are not updated during training.
|
||||
"""
|
||||
self.encoder.freeze_feature_encoder()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs,
|
||||
attention_mask,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
decoder_position_ids,
|
||||
encoder_outputs=None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
deterministic: bool = True,
|
||||
):
|
||||
if encoder_outputs is None:
|
||||
encoder_outputs = self.encoder(
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
|
||||
# optionally project encoder_hidden_states
|
||||
if self.enc_to_dec_proj is not None:
|
||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||
|
||||
# compute correct encoder attention mask
|
||||
if attention_mask is not None:
|
||||
encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
|
||||
encoder_hidden_states.shape[1], attention_mask
|
||||
)
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# flax script modeling_flax_wav2vec2.py
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
position_ids=decoder_position_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
return FlaxSeq2SeqLMOutput(
|
||||
logits=decoder_outputs.logits,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
|
||||
class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
r"""
|
||||
[`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
|
||||
with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
|
||||
as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the
|
||||
encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder.
|
||||
"""
|
||||
|
||||
config_class = SpeechEncoderDecoderConfig
|
||||
base_model_prefix: str = "speech_encoder_decoder"
|
||||
module_class = FlaxSpeechEncoderDecoderModule
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SpeechEncoderDecoderConfig,
|
||||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs
|
||||
):
|
||||
if config.decoder.cross_attention_hidden_size is not None:
|
||||
# Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
|
||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||
raise ValueError(
|
||||
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||
"it has to be equal to the encoder's `hidden_size`. "
|
||||
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||
)
|
||||
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
|
||||
if input_shape is None:
|
||||
# speech encoders almost always downsample the sequence length dimension
|
||||
encoder_input_length = 1024
|
||||
decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
|
||||
input_shape = ((1, encoder_input_length), (1, decoder_input_length))
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
encoder_input_shape, decoder_input_shape = input_shape
|
||||
|
||||
# init input DeviceArrays
|
||||
inputs = jnp.zeros(encoder_input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(inputs)
|
||||
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
|
||||
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
||||
|
||||
batch_size, sequence_length = inputs.shape
|
||||
|
||||
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
|
||||
if not decoder_batch_size == batch_size:
|
||||
raise ValueError(
|
||||
f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
|
||||
)
|
||||
decoder_position_ids = jnp.broadcast_to(
|
||||
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
|
||||
)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
rngs,
|
||||
inputs,
|
||||
attention_mask,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
decoder_position_ids,
|
||||
)["params"]
|
||||
|
||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
|
||||
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
|
||||
`attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
|
||||
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
|
||||
cross-attention of the decoder.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
||||
decoder_position_ids = jnp.broadcast_to(
|
||||
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
|
||||
)
|
||||
|
||||
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
|
||||
decoder_module = module._get_decoder_module()
|
||||
return decoder_module(
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
position_ids=decoder_position_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0),
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
decoder_position_ids=decoder_position_ids,
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
init_cache=True,
|
||||
method=_decoder_forward, # we only need to call the decoder to init the cache
|
||||
)
|
||||
return unfreeze(init_variables["cache"])
|
||||
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
|
||||
return self.module._get_feat_extract_output_lengths(input_lengths)
|
||||
|
||||
@add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def encode(
|
||||
self,
|
||||
inputs: jnp.ndarray,
|
||||
attention_mask: Optional[jnp.ndarray] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
train: bool = False,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import FlaxSpeechEncoderDecoderModel
|
||||
|
||||
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
|
||||
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
|
||||
... )
|
||||
|
||||
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
|
||||
>>> encoder_outputs = model.encode(inputs)
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(inputs)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
def _encoder_forward(module, inputs, attention_mask, **kwargs):
|
||||
encode_module = module._get_encoder_module()
|
||||
return encode_module(inputs, attention_mask, **kwargs)
|
||||
|
||||
outputs = self.module.apply(
|
||||
{"params": params or self.params},
|
||||
inputs=jnp.array(inputs, dtype="i4"),
|
||||
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
deterministic=not train,
|
||||
rngs=rngs,
|
||||
method=_encoder_forward,
|
||||
)
|
||||
|
||||
if return_dict:
|
||||
outputs = FlaxBaseModelOutput(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
@add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def decode(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
encoder_outputs,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
decoder_position_ids: Optional[jnp.ndarray] = None,
|
||||
past_key_values: dict = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
train: bool = False,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import FlaxSpeechEncoderDecoderModel
|
||||
>>> import jax.numpy as jnp
|
||||
|
||||
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
|
||||
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
|
||||
... )
|
||||
|
||||
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
|
||||
>>> encoder_outputs = model.encode(inputs)
|
||||
|
||||
>>> decoder_start_token_id = model.config.decoder.bos_token_id
|
||||
>>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id
|
||||
|
||||
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
||||
>>> logits = outputs.logits
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
if encoder_attention_mask is None:
|
||||
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
||||
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
||||
|
||||
batch_size, sequence_length = decoder_input_ids.shape
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
||||
|
||||
if decoder_position_ids is None:
|
||||
if past_key_values is not None:
|
||||
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
|
||||
|
||||
decoder_position_ids = jnp.broadcast_to(
|
||||
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
params = {"params": params or self.params}
|
||||
|
||||
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
||||
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
||||
# it can be changed by FlaxBartAttention module
|
||||
if past_key_values:
|
||||
params["cache"] = past_key_values
|
||||
mutable = ["cache"]
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
def _decoder_forward(
|
||||
module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
|
||||
):
|
||||
|
||||
projection_module = module._get_projection_module()
|
||||
decoder_module = module._get_decoder_module()
|
||||
|
||||
# optionally project encoder_hidden_states
|
||||
if projection_module is not None:
|
||||
encoder_hidden_states = projection_module(encoder_hidden_states)
|
||||
|
||||
return decoder_module(
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
decoder_position_ids,
|
||||
encoder_hidden_states,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
outputs = self.module.apply(
|
||||
params,
|
||||
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
||||
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
||||
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
deterministic=not train,
|
||||
rngs=rngs,
|
||||
mutable=mutable,
|
||||
method=_decoder_forward,
|
||||
)
|
||||
|
||||
# add updated cache to model output
|
||||
if past_key_values is not None and return_dict:
|
||||
outputs, past = outputs
|
||||
outputs["past_key_values"] = unfreeze(past["cache"])
|
||||
return outputs
|
||||
elif past_key_values is not None and not return_dict:
|
||||
outputs, past = outputs
|
||||
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
||||
|
||||
return outputs
|
||||
|
||||
@add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def __call__(
|
||||
self,
|
||||
inputs: jnp.ndarray,
|
||||
attention_mask: Optional[jnp.ndarray] = None,
|
||||
decoder_input_ids: Optional[jnp.ndarray] = None,
|
||||
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
decoder_position_ids: Optional[jnp.ndarray] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
train: bool = False,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
|
||||
|
||||
>>> # load a fine-tuned wav2vec2-2-bart model
|
||||
>>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
|
||||
>>> # load output tokenizer
|
||||
>>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")
|
||||
|
||||
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
|
||||
|
||||
>>> # use bart's special bos, pad and eos tokens
|
||||
>>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
|
||||
>>> model.config.pad_token_id = model.decoder.config.pad_token_id
|
||||
>>> model.config.eos_token_id = model.decoder.config.eos_token_id
|
||||
|
||||
>>> outputs = model.generate(inputs)
|
||||
# Assert something? More interesting input? dtype correct?
|
||||
```
|
||||
"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
# prepare encoder inputs
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(inputs)
|
||||
|
||||
# prepare decoder inputs
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
||||
if decoder_position_ids is None:
|
||||
batch_size, sequence_length = decoder_input_ids.shape
|
||||
decoder_position_ids = jnp.broadcast_to(
|
||||
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
inputs=jnp.array(inputs, dtype="i4"),
|
||||
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
||||
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
||||
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
||||
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
deterministic=not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
max_length,
|
||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = decoder_input_ids.shape
|
||||
|
||||
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
|
||||
# But since the decoder uses a causal mask, those positions are masked anyways.
|
||||
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if decoder_attention_mask is not None:
|
||||
decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
|
||||
else:
|
||||
decoder_position_ids = jnp.broadcast_to(
|
||||
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
|
||||
)
|
||||
|
||||
return {
|
||||
"past_key_values": past_key_values,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"encoder_attention_mask": attention_mask,
|
||||
"decoder_attention_mask": extended_attention_mask,
|
||||
"decoder_position_ids": decoder_position_ids,
|
||||
}
|
||||
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
|
||||
return model_kwargs
|
||||
|
||||
@classmethod
|
||||
def from_encoder_decoder_pretrained(
|
||||
cls,
|
||||
encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
*model_args,
|
||||
**kwargs
|
||||
) -> FlaxPreTrainedModel:
|
||||
r"""
|
||||
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
|
||||
checkpoints.
|
||||
|
||||
Params:
|
||||
encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
|
||||
Information necessary to initiate the encoder. Can be either:
|
||||
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
||||
user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
|
||||
decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
|
||||
Information necessary to initiate the decoder. Can be either:
|
||||
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
||||
user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
|
||||
model_args (remaining positional arguments, *optional*):
|
||||
All remaning positional arguments will be passed to the underlying model's `__init__` method.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
`output_attentions=True`).
|
||||
|
||||
- To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
|
||||
- To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
|
||||
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
||||
|
||||
Behaves differently depending on whether a `config` is provided or automatically loaded.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import FlaxSpeechEncoderDecoderModel
|
||||
|
||||
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
|
||||
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
|
||||
... )
|
||||
>>> # saving model after fine-tuning
|
||||
>>> model.save_pretrained("./wav2vec2-2-bart-large")
|
||||
>>> # load fine-tuned model
|
||||
>>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
|
||||
```"""
|
||||
|
||||
kwargs_encoder = {
|
||||
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
|
||||
}
|
||||
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
# remove encoder, decoder kwargs from kwargs
|
||||
for key in kwargs_encoder.keys():
|
||||
del kwargs["encoder_" + key]
|
||||
for key in kwargs_decoder.keys():
|
||||
del kwargs["decoder_" + key]
|
||||
|
||||
# Load and initialize the encoder and decoder
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs_encoder.pop("model", None)
|
||||
if encoder is None:
|
||||
if encoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
logger.info(
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||
)
|
||||
encoder_config.is_decoder = False
|
||||
encoder_config.add_cross_attention = False
|
||||
|
||||
kwargs_encoder["config"] = encoder_config
|
||||
|
||||
encoder = FlaxAutoModel.from_pretrained(
|
||||
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
|
||||
)
|
||||
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
if decoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||
"cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
|
||||
kwargs_decoder["config"] = decoder_config
|
||||
|
||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||
logger.warning(
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
|
||||
# instantiate config with corresponding kwargs
|
||||
dtype = kwargs.pop("dtype", jnp.float32)
|
||||
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
||||
|
||||
# init model
|
||||
model = cls(config, dtype=dtype)
|
||||
model.params["encoder"] = encoder.params
|
||||
model.params["decoder"] = decoder.params
|
||||
|
||||
return model
|
||||
|
|
@ -1012,6 +1012,26 @@ class FlaxWav2Vec2Module(nn.Module):
|
|||
|
||||
return input_lengths
|
||||
|
||||
def _get_feature_vector_attention_mask(
|
||||
self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
|
||||
):
|
||||
|
||||
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
||||
# on inference mode.
|
||||
non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
|
||||
|
||||
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
|
||||
|
||||
batch_size = attention_mask.shape[0]
|
||||
|
||||
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
|
||||
# these two operations makes sure that all values before the output lengths idxs are attended to
|
||||
attention_mask = attention_mask.at[(jnp.arange(attention_mask.shape[0]), output_lengths - 1)].set(1)
|
||||
attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
|
||||
|
||||
attention_mask = jnp.array(attention_mask, dtype=bool)
|
||||
return attention_mask
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
|
|
|
|||
|
|
@ -879,6 +879,13 @@ class FlaxRoFormerPreTrainedModel(metaclass=DummyObject):
|
|||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxSpeechEncoderDecoderModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxT5ForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,555 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
|
||||
|
||||
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
|
||||
from ..test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers import (
|
||||
FlaxGPT2LMHeadModel,
|
||||
FlaxSpeechEncoderDecoderModel,
|
||||
FlaxWav2Vec2Model,
|
||||
SpeechEncoderDecoderConfig,
|
||||
)
|
||||
from transformers.modeling_flax_outputs import FlaxBaseModelOutput
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import SpeechEncoderDecoderModel
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxEncoderDecoderMixin:
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_pretrained_model(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained_configs(
|
||||
self,
|
||||
config,
|
||||
inputs,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
):
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||
|
||||
enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model(
|
||||
self,
|
||||
config,
|
||||
inputs,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
encoder_outputs = FlaxBaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
|
||||
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
attention_mask, decoder_input_ids, decoder_attention_mask, encoder_outputs=encoder_outputs
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained(
|
||||
self,
|
||||
config,
|
||||
inputs,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
return_dict,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
|
||||
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
def check_save_and_load(
|
||||
self,
|
||||
config,
|
||||
inputs,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
|
||||
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||
|
||||
outputs = enc_dec_model(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_2 = np.array(outputs[0])
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
enc_dec_model.save_pretrained(tmpdirname)
|
||||
FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||
|
||||
after_outputs = enc_dec_model(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_1 = np.array(after_outputs[0])
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 4e-2)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
inputs,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
|
||||
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
)
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||
|
||||
seq_len = enc_dec_model._get_feat_extract_output_lengths(inputs.shape[1])
|
||||
self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads, seq_len, seq_len))
|
||||
|
||||
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
|
||||
num_decoder_layers = (
|
||||
decoder_config.num_decoder_layers
|
||||
if hasattr(decoder_config, "num_decoder_layers")
|
||||
else decoder_config.num_hidden_layers
|
||||
)
|
||||
self.assertEqual(len(decoder_attentions), num_decoder_layers)
|
||||
|
||||
self.assertEqual(
|
||||
decoder_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
cross_attentions = outputs_encoder_decoder["cross_attentions"]
|
||||
self.assertEqual(len(cross_attentions), num_decoder_layers)
|
||||
|
||||
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
|
||||
|
||||
self.assertEqual(
|
||||
cross_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len),
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_generate(self, inputs, config, decoder_config, **kwargs):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
|
||||
enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||
|
||||
pad_token_id = enc_dec_model.config.decoder.pad_token_id
|
||||
eos_token_id = enc_dec_model.config.decoder.eos_token_id
|
||||
decoder_start_token_id = enc_dec_model.config.decoder.decoder_start_token_id
|
||||
|
||||
# Copied from generation_utils (GPT2 doesn't have `pad_token_id`)
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
pad_token_id = eos_token_id
|
||||
if decoder_start_token_id is None:
|
||||
decoder_start_token_id = enc_dec_model.config.decoder.bos_token_id
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
# Copied from `test_modeling_encoder_decoder.py`
|
||||
if decoder_start_token_id is None:
|
||||
decoder_start_token_id = pad_token_id
|
||||
|
||||
generated_output = enc_dec_model.generate(
|
||||
inputs,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
)
|
||||
generated_sequences = generated_output.sequences
|
||||
self.assertEqual(generated_sequences.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
pt_logits = pt_outputs.logits
|
||||
pt_outputs = pt_outputs.to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
fx_logits = fx_outputs.logits
|
||||
fx_outputs = fx_outputs.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
fx_logits_loaded = fx_outputs_loaded.logits
|
||||
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
pt_logits_loaded = pt_outputs_loaded.logits
|
||||
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
|
||||
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained_return_dict(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
|
||||
|
||||
def test_save_and_load_from_pretrained(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_save_and_load(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_output_attentions(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_generate(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
config = config_inputs_dict.pop("config")
|
||||
decoder_config = config_inputs_dict.pop("decoder_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
# `encoder_hidden_states` is not used in model call/forward
|
||||
del inputs_dict["encoder_hidden_states"]
|
||||
|
||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||
inputs_dict["decoder_attention_mask"] = np.concatenate(
|
||||
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
|
||||
)
|
||||
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
decoder_config.use_cache = False
|
||||
|
||||
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
|
||||
|
||||
# check without `enc_to_dec_proj` projection
|
||||
decoder_config.hidden_size = config.hidden_size
|
||||
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# check `enc_to_dec_proj` work as expected
|
||||
decoder_config.hidden_size = decoder_config.hidden_size * 2
|
||||
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
inputs = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
|
||||
decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
|
||||
attention_mask = ids_tensor([13, 5], vocab_size=2)
|
||||
|
||||
outputs = model_2(
|
||||
inputs=inputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
out_2 = np.array(outputs[0])
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
model_2.save_pretrained(tmp_dirname)
|
||||
model_1 = FlaxSpeechEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||
|
||||
after_outputs = model_1(
|
||||
inputs=inputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
out_1 = np.array(after_outputs[0])
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 4e-2)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"facebook/wav2vec2-large-lv60", "gpt2-medium"
|
||||
)
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], model.config.encoder.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], model.config.decoder.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {
|
||||
"inputs": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = FlaxWav2Vec2Model(config)
|
||||
decoder_model = FlaxGPT2LMHeadModel(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = FlaxWav2Vec2ModelTester(self, batch_size=13)
|
||||
model_tester_decoder = FlaxGPT2ModelTester(self, batch_size=13)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||
(config, inputs, attention_mask) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"inputs": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_flaxwav2vec2gpt2_pt_flax_equivalence(self):
|
||||
pt_model = SpeechEncoderDecoderModel.from_pretrained("jsnfly/wav2vec2-large-xlsr-53-german-gpt2")
|
||||
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
|
||||
"jsnfly/wav2vec2-large-xlsr-53-german-gpt2", from_pt=True
|
||||
)
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs_dict = {
|
||||
"inputs": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
pt_logits = pt_outputs.logits
|
||||
pt_outputs = pt_outputs.to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
fx_logits = fx_outputs.logits
|
||||
fx_outputs = fx_outputs.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
fx_logits_loaded = fx_outputs_loaded.logits
|
||||
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
pt_logits_loaded = pt_outputs_loaded.logits
|
||||
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
|
@ -215,6 +215,7 @@ def get_model_modules():
|
|||
"modeling_flax_encoder_decoder",
|
||||
"modeling_flax_utils",
|
||||
"modeling_speech_encoder_decoder",
|
||||
"modeling_flax_speech_encoder_decoder",
|
||||
"modeling_flax_vision_encoder_decoder",
|
||||
"modeling_transfo_xl_utilities",
|
||||
"modeling_tf_auto",
|
||||
|
|
@ -290,6 +291,7 @@ def get_model_test_files():
|
|||
"test_modeling_common",
|
||||
"test_modeling_encoder_decoder",
|
||||
"test_modeling_flax_encoder_decoder",
|
||||
"test_modeling_flax_speech_encoder_decoder",
|
||||
"test_modeling_marian",
|
||||
"test_modeling_tf_common",
|
||||
"test_modeling_tf_encoder_decoder",
|
||||
|
|
|
|||
Loading…
Reference in a new issue