Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
Improve TF weight loading, especially PT crossloading (#21792)
* First commit for the improved PT-TF weight loading
* Remove workarounds from TFEncoderDecoder tests
* Allow a custom weight renaming function in from_pretrained and use that to clean up EncoderDecoder
* make fixup
* First attempt at visionencoderdecoder
* Disable tensorfloat32 in tests to get consistent outputs
* Quick fix to tf_vision_encoder_decoder tests
* make fixup
* Update Blenderbot tests
* Remove unused arg in modeling_tf_opt
* load_tf_sharded_weights had strict=True! This meant transfer learning was impossible, so I'm setting it to False.
* Support prefixes when loading sharded TF checkpoints
* make fixup
* Add test to load sharded models with a weight prefix
* Fix sharded weight loading test
* Add a test for transfer from a sharded checkpoint
* make fixup
* Add test to check that crossloading from PT with a prefix works
* Refactor from_pretrained in the encoderdecoder classes
* Refactor from_pretrained in the encoderdecoder classes
* missmatched -> mismatched
* Explicitly check for None
* No comments showing my very impressive and attractive knowledge of Py3.9+
* Disable TF32 across all TF tests
This commit is contained in:
parent 871c31a6f1
commit acfb714bdf
7 changed files with 147 additions and 148 deletions
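Note: the user-visible effect of the headline change is that PyTorch checkpoints for composite models now crossload directly, with no temporary save/reload round-trip. A minimal sketch of the new path (the checkpoint name is taken from the docstring example further down this diff; `from_pt=True` assumes the checkpoint ships only PyTorch weights):

```python
from transformers import TFEncoderDecoderModel

# Crossload a PyTorch encoder-decoder checkpoint straight into TF. The
# tf_to_pt_weight_rename hook that reconciles TF and PT weight names is
# installed internally by TFEncoderDecoderModel.from_pretrained.
model = TFEncoderDecoderModel.from_pretrained(
    "ydshieh/bert2bert-cnn_dailymail-fp16", from_pt=True
)
```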
@@ -39,7 +39,9 @@ class TransposeType(ExplicitEnum):
     CONV2D = "conv2d"


-def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", tf_weight_shape=None):
+def convert_tf_weight_name_to_pt_weight_name(
+    tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None
+):
     """
     Convert a TF 2.0 model variable name in a pytorch model weight name.

@@ -54,6 +56,14 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
         - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be
           transposed with regards to each other
     """
+    if name_scope is not None:
+        if not tf_name.startswith(name_scope):
+            raise ValueError(
+                f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
+                "in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
+            )
+        tf_name = tf_name[len(name_scope) :]
+        tf_name = tf_name.lstrip("/")
     tf_name = tf_name.replace(":0", "")  # device ids
     tf_name = re.sub(
         r"/[^/]*___([^/]*)/", r"/\1/", tf_name
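Note: a standalone sketch of what the new `name_scope` handling does before the usual TF-to-PT name conversion runs (this mirrors the added block above; it is not the library function itself, and the weight name is hypothetical):

```python
from typing import Optional


def strip_name_scope(tf_name: str, name_scope: Optional[str]) -> str:
    # Peel off the scope prefix, then the leading slash, exactly as the
    # added block does.
    if name_scope is not None:
        if not tf_name.startswith(name_scope):
            raise ValueError(f"Weight name {tf_name} does not start with {name_scope}")
        tf_name = tf_name[len(name_scope) :]
    return tf_name.lstrip("/")


print(strip_name_scope("a/b/bert/embeddings/word_embeddings/weight:0", "a/b"))
# -> "bert/embeddings/word_embeddings/weight:0"
```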
@@ -144,7 +154,13 @@ def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf


 def load_pytorch_checkpoint_in_tf2_model(
-    tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+    tf_model,
+    pytorch_checkpoint_path,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
 ):
     """Load pytorch checkpoints in a TF 2.0 model"""
     try:
@@ -176,6 +192,8 @@ def load_pytorch_checkpoint_in_tf2_model(
         tf_inputs=tf_inputs,
         allow_missing_keys=allow_missing_keys,
         output_loading_info=output_loading_info,
+        _prefix=_prefix,
+        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
     )

@@ -189,7 +207,13 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi


 def load_pytorch_weights_in_tf2_model(
-    tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+    tf_model,
+    pt_state_dict,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
 ):
     """Load pytorch state_dict in a TF 2.0 model."""
     try:
@@ -209,11 +233,19 @@ def load_pytorch_weights_in_tf2_model(
         tf_inputs=tf_inputs,
         allow_missing_keys=allow_missing_keys,
         output_loading_info=output_loading_info,
+        _prefix=_prefix,
+        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
     )


 def load_pytorch_state_dict_in_tf2_model(
-    tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+    tf_model,
+    pt_state_dict,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
 ):
     """Load a pytorch state_dict in a TF 2.0 model."""
     import tensorflow as tf
@@ -227,8 +259,11 @@ def load_pytorch_state_dict_in_tf2_model(
     if tf_inputs is None:
         tf_inputs = tf_model.dummy_inputs

+    if _prefix is None:
+        _prefix = ""
     if tf_inputs is not None:
-        tf_model(tf_inputs, training=False)  # Make sure model is built
+        with tf.name_scope(_prefix):
+            tf_model(tf_inputs, training=False)  # Make sure model is built
     # Adapt state dict - TODO remove this and update the AWS weights files instead
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
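Note: this `_prefix` plumbing, together with building the model under `tf.name_scope(_prefix)`, is what makes `load_weight_prefix` work end to end for PT crossloading. Based on the tests at the bottom of this diff, usage looks like:

```python
from transformers import TFBertModel

# Building under the "a/b" name scope prefixes every variable name, which is
# how composite models keep their sub-models' weights namespaced.
model = TFBertModel.from_pretrained(
    "hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b"
)
assert all(w.name.startswith("a/b/") for w in model.weights)
```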
@@ -249,8 +284,10 @@ def load_pytorch_state_dict_in_tf2_model(
     for old_key, new_key in zip(old_keys, new_keys):
         pt_state_dict[new_key] = pt_state_dict.pop(old_key)

-    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
-    # TF models always have a prefix, some of PyTorch models (base ones) don't
+    # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
+    # In PT, the derived models (with heads) use the base model class as the stem instead, and the base model
+    # just contains the stem itself, and there is no MainLayer class. This means that TF base classes have one
+    # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
     start_prefix_to_remove = ""
     if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()):
         start_prefix_to_remove = tf_model.base_model_prefix + "."
@@ -263,8 +300,13 @@ def load_pytorch_state_dict_in_tf2_model(
     for symbolic_weight in symbolic_weights:
         sw_name = symbolic_weight.name
         name, transpose = convert_tf_weight_name_to_pt_weight_name(
-            sw_name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=symbolic_weight.shape
+            sw_name,
+            start_prefix_to_remove=start_prefix_to_remove,
+            tf_weight_shape=symbolic_weight.shape,
+            name_scope=_prefix,
         )
+        if tf_to_pt_weight_rename is not None:
+            name = tf_to_pt_weight_rename(name)

         # Find associated numpy array in pytorch model state dict
         if name not in pt_state_dict:
@@ -707,7 +707,7 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB"):
     return shards, index


-def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=True):
+def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None):
     """
     This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
     the TF weights from the shard file accordingly to their names and shapes.
@@ -729,32 +729,35 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
     """

     # Load the index
     missing_keys = []
     unexpected_keys = set()
     saved_keys = set()
-    missmatched_keys = set()
+    mismatched_keys = set()

     # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
     # the weight, we have to get rid of the first prefix of the name of the layer.
     model_keys = set()
     model_layer_map = {}
     for i, k in enumerate(model.weights):
-        if "model." in k.name or len(k.name.split("/")) == 1:
-            layer_name = k.name
-        else:
-            layer_name = "/".join(k.name.split("/")[1:])
+        layer_name = k.name
+        if _prefix is not None and layer_name.startswith(_prefix):
+            layer_name = layer_name[len(_prefix) :]
+        layer_name = layer_name.lstrip("/")
+        if not ("model." in layer_name or len(layer_name.split("/")) == 1):
+            layer_name = "/".join(layer_name.split("/")[1:])
         model_keys.add(layer_name)
         model_layer_map[layer_name] = i

     for shard_file in shard_files:
-        state_dict = tf.io.read_file(shard_file)
-        saved_weight_names_set, unexpected_keys_set, missmatched_keys_set = load_tf_shard(
-            model, model_layer_map, shard_file, ignore_mismatched_sizes=ignore_mismatched_sizes
+        saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard(
+            model,
+            model_layer_map,
+            shard_file,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            _prefix=_prefix,
         )
         saved_keys.update(saved_weight_names_set)
         unexpected_keys.update(unexpected_keys_set)
-        missmatched_keys.update(missmatched_keys_set)
-        del state_dict
+        mismatched_keys.update(mismatched_keys_set)
         gc.collect()

     missing_keys = model_keys - saved_keys
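Note: the old `strict=True` default was the transfer-learning bug called out in the commit message: loading a sharded base checkpoint into a model with a fresh head always raised on the (expected) missing head weights. With `strict=False` those keys are merely reported as missing, so this now works (the checkpoint name comes from the new test at the end of this diff):

```python
from transformers import TFBertForSequenceClassification

# The sharded base checkpoint has no classifier weights; with strict=False
# the loader leaves the head freshly initialized instead of raising.
model = TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded")
```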
@@ -768,10 +771,10 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
             error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    return missing_keys, unexpected_keys, missmatched_keys
+    return missing_keys, unexpected_keys, mismatched_keys


-def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False):
+def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     """
     Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.

@@ -783,11 +786,11 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch

     Returns:
         `tf.keras.models.Model`: Three lists, one for the layers that were found and succesfully restored (from the
-        shard file), one for the missmatched layers, and another one for the unexpected layers.
+        shard file), one for the mismatched layers, and another one for the unexpected layers.
     """
     saved_weight_names_set = set()
     saved_weights = {}
-    missmatched_keys = set()
+    mismatched_keys = set()
     unexpected_keys = set()
     # Read the H5 file
     try:
@@ -822,7 +825,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
                         array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                     except ValueError as e:
                         if ignore_mismatched_sizes:
-                            missmatched_keys.add(
+                            mismatched_keys.add(
                                 (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
                             )
                             continue
@@ -836,7 +839,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch

             K.batch_set_value(weight_value_tuples)

-            return saved_weight_names_set, unexpected_keys, missmatched_keys
+            return saved_weight_names_set, unexpected_keys, mismatched_keys

     except Exception as e:
         try:
@@ -2458,6 +2461,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                 specify the folder name here.
+            tf_to_pt_weight_rename (`Callable`, *optional*):
+                A function that is called to transform the names of weights during the PyTorch to TensorFlow
+                crossloading process. This is not necessary for most models, but is useful to allow composite models to
+                be crossloaded correctly.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -2506,6 +2513,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         from_auto_class = kwargs.pop("_from_auto", False)
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)
+        tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)

         if trust_remote_code is True:
             logger.warning(
@@ -2745,7 +2753,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

                 # Load from a PyTorch checkpoint
                 return load_pytorch_checkpoint_in_tf2_model(
-                    model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
+                    model,
+                    resolved_archive_file,
+                    allow_missing_keys=True,
+                    output_loading_info=output_loading_info,
+                    _prefix=load_weight_prefix,
+                    tf_to_pt_weight_rename=tf_to_pt_weight_rename,
                 )

             # we might need to extend the variable scope for composite models
@@ -2761,7 +2774,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 state_dict = safe_load_file(resolved_archive_file)
                 # Load from a PyTorch checkpoint
                 return load_pytorch_state_dict_in_tf2_model(
-                    model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
+                    model,
+                    state_dict,
+                    allow_missing_keys=True,
+                    output_loading_info=output_loading_info,
+                    _prefix=load_weight_prefix,
                 )

             # 'by_name' allow us to do transfer learning by skipping/adding layers
@@ -2775,6 +2792,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                     model,
                     resolved_archive_file,
                     ignore_mismatched_sizes=ignore_mismatched_sizes,
+                    _prefix=load_weight_prefix,
                 )
             else:
                 missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
@@ -15,9 +15,7 @@
 """ Classes to support TF Encoder-Decoder architectures"""


-import gc
-import os
-import tempfile
+import re
 import warnings
 from typing import Optional, Tuple, Union

@@ -306,46 +304,23 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):

         >>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
         ```"""
+        # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
+        # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
+        # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
+        # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
+        # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
-        from_pt = kwargs.pop("from_pt", False)
-        if from_pt:
-            import torch
+        if kwargs.get("from_pt", False):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+            encoder_model_type = config.encoder.model_type

-            from transformers import EncoderDecoderModel
-
-            # a workaround to load from pytorch checkpoint
-            _model = EncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            config = _model.config
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                encoder_dir = os.path.join(tmpdirname, "encoder")
-                decoder_dir = os.path.join(tmpdirname, "decoder")
-                _model.encoder.save_pretrained(encoder_dir)
-                _model.decoder.save_pretrained(decoder_dir)
-
-                if hasattr(_model, "enc_to_dec_proj"):
-                    enc_to_dec_proj_kernel = tf.transpose(
-                        tf.constant(_model.enc_to_dec_proj.weight.detach().to("cpu").numpy()), perm=(1, 0)
-                    )
-                    enc_to_dec_proj_bias = tf.constant(_model.enc_to_dec_proj.bias.detach().to("cpu").numpy())
-
-                del _model
-                gc.collect()
-                torch.cuda.empty_cache()
-
-                model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
-                    encoder_dir, decoder_dir, encoder_from_pt=True, decoder_from_pt=True
-                )
-                # This is only for copying some specific attributes of this particular model.
-                model.config = config
-
-                if hasattr(model, "enc_to_dec_proj"):
-                    model(model.dummy_inputs)
-                    model.enc_to_dec_proj.kernel.assign(enc_to_dec_proj_kernel)
-                    model.enc_to_dec_proj.bias.assign(enc_to_dec_proj_bias)
-
-                return model
+            def tf_to_pt_weight_rename(tf_weight):
+                if "encoder" in tf_weight and "decoder" not in tf_weight:
+                    return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
+                else:
+                    return tf_weight
+
+            kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

     @classmethod
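Note: to make the rename hook concrete, here is the closure above run on a hypothetical weight name (`encoder_model_type` would normally come from `config.encoder.model_type`):

```python
import re

encoder_model_type = "bert"  # illustrative; read from config.encoder.model_type


def tf_to_pt_weight_rename(tf_weight):
    # TF composite models carry an extra MainLayer level ("encoder.bert.")
    # that plain PyTorch state dicts do not have.
    if "encoder" in tf_weight and "decoder" not in tf_weight:
        return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
    return tf_weight


print(tf_to_pt_weight_rename("encoder.bert.embeddings.word_embeddings.weight"))
# -> "encoder.embeddings.word_embeddings.weight"
```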
@@ -451,14 +426,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
             encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

-            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
-            if kwargs_encoder.get("from_pt", None):
-                del kwargs_encoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    encoder.save_pretrained(tmp_dirname)
-                    del encoder
-                    encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)
-
         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
             if decoder_pretrained_model_name_or_path is None:
@@ -493,14 +460,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
             decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

-            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
-            if kwargs_decoder.get("from_pt", None):
-                del kwargs_decoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    decoder.save_pretrained(tmp_dirname)
-                    del decoder
-                    decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)
-
         # Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
         if encoder.name != "encoder":
             raise ValueError("encoder model must be created with the name `encoder`.")
@@ -486,7 +486,7 @@ OPT_INPUTS_DOCSTRING = r"""
 class TFOPTDecoder(tf.keras.layers.Layer):
     config_class = OPTConfig

-    def __init__(self, config: OPTConfig, load_weight_prefix=None, **kwargs):
+    def __init__(self, config: OPTConfig, **kwargs):
         super().__init__(**kwargs)
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -15,9 +15,7 @@
 """ Classes to support TF Vision-Encoder-Text-Decoder architectures"""


-import gc
-import os
-import tempfile
+import re
 import warnings
 from typing import Optional, Tuple, Union

@@ -320,46 +318,23 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos

         >>> assert preds == ["a cat laying on top of a couch next to another cat"]
         ```"""
+        # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
+        # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
+        # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
+        # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
+        # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
-        from_pt = kwargs.pop("from_pt", False)
-        if from_pt:
-            import torch
+        if kwargs.get("from_pt", False):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+            encoder_model_type = config.encoder.model_type

-            from transformers import VisionEncoderDecoderModel
-
-            # a workaround to load from pytorch checkpoint
-            _model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            config = _model.config
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                encoder_dir = os.path.join(tmpdirname, "encoder")
-                decoder_dir = os.path.join(tmpdirname, "decoder")
-                _model.encoder.save_pretrained(encoder_dir)
-                _model.decoder.save_pretrained(decoder_dir)
-
-                if hasattr(_model, "enc_to_dec_proj"):
-                    enc_to_dec_proj_kernel = tf.transpose(
-                        tf.constant(_model.enc_to_dec_proj.weight.detach().to("cpu").numpy()), perm=(1, 0)
-                    )
-                    enc_to_dec_proj_bias = tf.constant(_model.enc_to_dec_proj.bias.detach().to("cpu").numpy())
-
-                del _model
-                gc.collect()
-                torch.cuda.empty_cache()
-
-                model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
-                    encoder_dir, decoder_dir, encoder_from_pt=True, decoder_from_pt=True
-                )
-                # This is only for copying some specific attributes of this particular model.
-                model.config = config
-
-                if hasattr(model, "enc_to_dec_proj"):
-                    model(model.dummy_inputs)
-                    model.enc_to_dec_proj.kernel.assign(enc_to_dec_proj_kernel)
-                    model.enc_to_dec_proj.bias.assign(enc_to_dec_proj_bias)
-
-                return model
+            def tf_to_pt_weight_rename(tf_weight):
+                if "encoder" in tf_weight and "decoder" not in tf_weight:
+                    return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
+                else:
+                    return tf_weight
+
+            kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

     @classmethod
@@ -466,15 +441,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
             kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
             encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

-            # Necessary to make `save_pretrained -> from_pretrained` work correctly for the converted PT -> TF model.
-            # See https://github.com/huggingface/transformers/pull/14016#issuecomment-944046313
-            if kwargs_encoder.get("from_pt", None):
-                del kwargs_encoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    encoder.save_pretrained(tmp_dirname)
-                    del encoder
-                    encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)
-
         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
             if decoder_pretrained_model_name_or_path is None:
@@ -509,15 +475,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
             kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
             decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

-            # Necessary to make `save_pretrained -> from_pretrained` work correctly for the converted PT -> TF model.
-            # See https://github.com/huggingface/transformers/pull/14016#issuecomment-944046313
-            if kwargs_decoder.get("from_pt", None):
-                del kwargs_decoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    decoder.save_pretrained(tmp_dirname)
-                    del decoder
-                    decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)
-
         # Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
         if encoder.name != "encoder":
             raise ValueError("encoder model must be created with the name `encoder`.")
@@ -925,16 +925,14 @@ class TFViT2GPT2ModelIntegrationTest(unittest.TestCase):
         self.assertLessEqual(max_diff, 1e-4)

         def generate_step(pixel_values):
-            outputs = model.generate(
-                pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True
-            )
+            outputs = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True)
             output_ids = outputs.sequences
             preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
             preds = [pred.strip() for pred in preds]

-            return preds, outputs.scores.numpy()
+            return preds

-        preds, scores = generate_step(pixel_values)
+        preds = generate_step(pixel_values)

         # should produce
         # ["a cat laying on top of a couch next to another cat"]
@@ -90,6 +90,7 @@ if is_tf_available():
         TFAutoModel,
         TFAutoModelForSequenceClassification,
         TFBertForMaskedLM,
+        TFBertForSequenceClassification,
         TFBertModel,
         TFRagModel,
         TFSharedEmbeddings,
@@ -107,6 +108,8 @@ if is_tf_available():
     from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs
     from transformers.tf_utils import stable_softmax

+    tf.config.experimental.enable_tensor_float_32_execution(False)
+
     if _tf_gpu_memory_limit is not None:
         gpus = tf.config.list_physical_devices("GPU")
         for gpu in gpus:
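Note: on Ampere-class GPUs TensorFlow enables TensorFloat-32 by default, which rounds float32 matmul inputs to a 10-bit mantissa and can push TF outputs past the tolerances used when comparing against PyTorch references. Disabling it once at import time keeps the PT/TF equivalence tests deterministic:

```python
import tensorflow as tf

# Force full-precision float32 matmuls so TF outputs match the PT reference.
tf.config.experimental.enable_tensor_float_32_execution(False)
```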
@@ -2140,6 +2143,18 @@ class UtilsFunctionsTest(unittest.TestCase):
         for p1, p2 in zip(model.weights, ref_model.weights):
             assert np.allclose(p1.numpy(), p2.numpy())

+    def test_sharded_checkpoint_with_prefix(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b")
+        sharded_model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded", load_weight_prefix="a/b")
+        for p1, p2 in zip(model.weights, sharded_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+            self.assertTrue(p1.name.startswith("a/b/"))
+            self.assertTrue(p2.name.startswith("a/b/"))
+
+    def test_sharded_checkpoint_transfer(self):
+        # If this doesn't throw an error then the test passes
+        TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded")
+
     @is_pt_tf_cross_test
     def test_checkpoint_sharding_local_from_pt(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -2150,6 +2165,16 @@ class UtilsFunctionsTest(unittest.TestCase):
         for p1, p2 in zip(model.weights, ref_model.weights):
             assert np.allclose(p1.numpy(), p2.numpy())

+    @is_pt_tf_cross_test
+    def test_checkpoint_loading_with_prefix_from_pt(self):
+        model = TFBertModel.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", from_pt=True, load_weight_prefix="a/b"
+        )
+        ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
+        for p1, p2 in zip(model.weights, ref_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+            self.assertTrue(p1.name.startswith("a/b/"))
+
     @is_pt_tf_cross_test
     def test_checkpoint_sharding_hub_from_pt(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)