Custom TF weights loading (#7422)
* First try
* Fix TF utils
* Handle authorized unexpected keys when loading weights
* Add several more authorized unexpected keys
* Apply style
* Fix test
* Address Patrick's comments.
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Apply style
* Make return_dict the default behavior and display a warning message
* Revert
* Replace wrong keyword
* Revert code
* Add forgot key
* Fix bug in loading PT models from a TF one.
* Fix sort
* Add a test for custom load weights in BERT
* Apply style
* Remove unused import

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent d3adb985d1 · commit 9cf7b23b9b
4 changed files with 139 additions and 27 deletions
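In short: TF models gain an authorized_unexpected_keys attribute mirroring the existing authorized_missing_keys, and H5 checkpoints are now loaded by a custom load_tf_weights helper instead of Keras's by_name loading, so from_pretrained can report and filter layer mismatches. A minimal sketch of the user-visible effect (the checkpoint id is illustrative; output_loading_info is a pre-existing from_pretrained flag):

from transformers import TFBertForTokenClassification

# The expected pooler mismatch is filtered through authorized_unexpected_keys /
# authorized_missing_keys, so only genuinely surprising layers are reported.
model, info = TFBertForTokenClassification.from_pretrained(
    "bert-base-cased", output_loading_info=True
)
print(info["missing_keys"], info["unexpected_keys"])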
src/transformers/modeling_tf_bert.py

@@ -854,6 +854,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
 
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
@@ -939,6 +940,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
 
 class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
 
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
@@ -1286,6 +1288,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
 )
 class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
 
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
@@ -1369,6 +1372,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
 )
 class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
 
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
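All four heads above discard BERT's pooler, so checkpoint tensors whose names match the r"pooler" pattern are expected to be unused and should not trigger warnings. A small illustration of how such a pattern matches a TF weight name (the name below is made up for the example):

import re

weight_name = "tf_bert_for_masked_lm/bert/pooler/dense/kernel:0"  # hypothetical name
assert re.search(r"pooler", weight_name) is not None  # would be filtered from the warnings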
src/transformers/modeling_tf_pytorch_utils.py

@@ -177,6 +177,13 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
         elif len(symbolic_weight.shape) > len(array.shape):
             array = numpy.expand_dims(array, axis=0)
 
+        if list(symbolic_weight.shape) != list(array.shape):
+            try:
+                array = numpy.reshape(array, symbolic_weight.shape)
+            except AssertionError as e:
+                e.args += (symbolic_weight.shape, array.shape)
+                raise e
+
         try:
             assert list(symbolic_weight.shape) == list(array.shape)
         except AssertionError as e:
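The added block is a last-resort reshape for weights whose element count matches but whose shapes differ. Note that numpy.reshape raises ValueError (not AssertionError) on a genuine size mismatch, so truly incompatible weights still fail loudly at the reshape itself. A standalone sketch of the same rank-then-reshape logic, assuming a rank-1 saved array and a rank-2 target:

import numpy

array = numpy.zeros((768,))  # value coming from the PyTorch state dict
target_shape = (1, 768)      # shape of the matching TF symbolic weight
if len(target_shape) > len(array.shape):
    array = numpy.expand_dims(array, axis=0)  # align ranks first, as above
if list(target_shape) != list(array.shape):
    array = numpy.reshape(array, target_shape)  # only same-size arrays get here
assert array.shape == (1, 768)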
@@ -251,6 +258,8 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
 
     import transformers
 
+    from .modeling_tf_utils import load_tf_weights
+
     logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path))
 
     # Instantiate and load the associated TF 2.0 model
@@ -264,7 +273,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
     if tf_inputs is not None:
         tf_model(tf_inputs, training=False)  # Make sure model is built
 
-    tf_model.load_weights(tf_checkpoint_path, by_name=True)
+    load_tf_weights(tf_model, tf_checkpoint_path)
 
     return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys)
@@ -332,6 +341,13 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False):
         elif len(pt_weight.shape) > len(array.shape):
             array = numpy.expand_dims(array, axis=0)
 
+        if list(pt_weight.shape) != list(array.shape):
+            try:
+                array = numpy.reshape(array, pt_weight.shape)
+            except AssertionError as e:
+                e.args += (pt_weight.shape, array.shape)
+                raise e
+
         try:
             assert list(pt_weight.shape) == list(array.shape)
         except AssertionError as e:

src/transformers/modeling_tf_utils.py
@@ -23,12 +23,12 @@ from typing import Dict, List, Optional, Union
 import h5py
 import numpy as np
 import tensorflow as tf
+from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.saving import hdf5_format
 
 from .configuration_utils import PretrainedConfig
 from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
 from .generation_tf_utils import TFGenerationMixin
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 from .utils import logging
@@ -216,6 +216,91 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
     """
 
 
+def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
+    """
+    Detect missing and unexpected layers.
+
+    Args:
+        model (:obj:`tf.keras.models.Model`):
+            The model to load the weights into.
+        resolved_archive_file (:obj:`str`):
+            The location of the H5 file.
+
+    Returns:
+        Two lists, one for the missing layers, and another one for the unexpected layers.
+    """
+    missing_layers = []
+    unexpected_layers = []
+
+    with h5py.File(resolved_archive_file, "r") as f:
+        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        model_layer_names = set(layer.name for layer in model.layers)
+        missing_layers = list(model_layer_names - saved_layer_names)
+        unexpected_layers = list(saved_layer_names - model_layer_names)
+
+        for layer in model.layers:
+            if layer.name in saved_layer_names:
+                g = f[layer.name]
+                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
+                saved_weight_names_set = set(
+                    "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
+                )
+                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
+                symbolic_weights_names = set(
+                    "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
+                )
+                missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
+                unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
+
+    return missing_layers, unexpected_layers
+
+
+def load_tf_weights(model, resolved_archive_file):
+    """
+    Load the TF weights from a H5 file.
+
+    Args:
+        model (:obj:`tf.keras.models.Model`):
+            The model to load the weights into.
+        resolved_archive_file (:obj:`str`):
+            The location of the H5 file.
+    """
+    with h5py.File(resolved_archive_file, "r") as f:
+        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        weight_value_tuples = []
+
+        for layer in model.layers:
+            if layer.name in saved_layer_names:
+                g = f[layer.name]
+                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
+                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
+                saved_weight_names_values = {}
+
+                for weight_name in saved_weight_names:
+                    name = "/".join(weight_name.split("/")[1:])
+                    saved_weight_names_values[name] = np.asarray(g[weight_name])
+
+                for symbolic_weight in symbolic_weights:
+                    splited_layers = symbolic_weight.name.split("/")[1:]
+                    symbolic_weight_name = "/".join(splited_layers)
+
+                    if symbolic_weight_name in saved_weight_names_values:
+                        saved_weight_value = saved_weight_names_values[symbolic_weight_name]
+
+                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
+                            try:
+                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
+                            except AssertionError as e:
+                                e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape)
+                                raise e
+                        else:
+                            array = saved_weight_value
+
+                        weight_value_tuples.append((symbolic_weight, array))
+
+    K.batch_set_value(weight_value_tuples)
+
+
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     r"""
     Base class for all TF models.
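Taken together, these helpers reimplement Keras H5 loading with reporting: detect_tf_missing_unexpected_layers diffs the layer and per-layer weight names against the checkpoint, and load_tf_weights assigns all matched values in a single K.batch_set_value call. A hedged usage sketch, assuming a transformers build containing this commit ("my_model.h5" is a hypothetical path and downloading bert-base-uncased is required):

from transformers import TFBertModel
from transformers.modeling_tf_utils import detect_tf_missing_unexpected_layers, load_tf_weights

model = TFBertModel.from_pretrained("bert-base-uncased")
model.save_weights("my_model.h5")  # H5 file with layer_names/weight_names attributes

fresh = TFBertModel(model.config)
fresh(fresh.dummy_inputs, training=False)  # build the variables before loading

missing, unexpected = detect_tf_missing_unexpected_layers(fresh, "my_model.h5")
load_tf_weights(fresh, "my_model.h5")  # one K.batch_set_value call under the hood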
@@ -231,10 +316,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
+          from the model when loading the model weights (and avoid unnecessary warnings).
+        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
+          from the weights when loading the model weights (and avoid unnecessary warnings).
     """
     config_class = None
     base_model_prefix = ""
     authorized_missing_keys = None
+    authorized_unexpected_keys = None
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
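A task-specific subclass opts in by listing re patterns for tensors it deliberately ignores, as the BERT heads in this commit do. A minimal sketch (TFMyBertForTagging is a made-up class name for illustration):

from transformers import TFBertPreTrainedModel

class TFMyBertForTagging(TFBertPreTrainedModel):
    # checkpoint layers matching these re patterns are dropped from the
    # missing/unexpected warning lists computed in from_pretrained
    authorized_unexpected_keys = [r"pooler"]
    authorized_missing_keys = [r"pooler"]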
@@ -604,6 +694,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         model = cls(config, *model_args, **model_kwargs)
 
         if from_pt:
+            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
+
             # Load from a PyTorch checkpoint
             return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)
@@ -613,7 +705,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             # 'by_name' allow us to do transfer learning by skipping/adding layers
             # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
             try:
-                model.load_weights(resolved_archive_file, by_name=True)
+                load_tf_weights(model, resolved_archive_file)
             except OSError:
                 raise OSError(
                     "Unable to load weights from h5 file. "
@@ -622,23 +714,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
 
         model(model.dummy_inputs, training=False)  # Make sure restore ops are run
 
         # Check if the models are the same to output loading informations
-        with h5py.File(resolved_archive_file, "r") as f:
-            if "layer_names" not in f.attrs and "model_weights" in f:
-                f = f["model_weights"]
-            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
-        model_layer_names = set(layer.name for layer in model.layers)
-        missing_keys = list(model_layer_names - hdf5_layer_names)
-        unexpected_keys = list(hdf5_layer_names - model_layer_names)
-        error_msgs = []
+        missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)
 
         if cls.authorized_missing_keys is not None:
             for pat in cls.authorized_missing_keys:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
 
+        if cls.authorized_unexpected_keys is not None:
+            for pat in cls.authorized_unexpected_keys:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
         if len(unexpected_keys) > 0:
             logger.warning(
-                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
                 f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                 f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                 f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
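The filtering keeps only keys that match none of the authorized patterns; re.search matches anywhere in the name, so r"pooler" also covers nested weight names. A small self-contained check (the names are illustrative, not taken from a real checkpoint):

import re

unexpected_keys = ["bert/pooler/dense/kernel:0", "mlm___cls"]
for pat in [r"pooler"]:
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
assert unexpected_keys == ["mlm___cls"]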
@@ -646,25 +734,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
             )
         else:
-            logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")
 
         if len(missing_keys) > 0:
             logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                 f"and are newly initialized: {missing_keys}\n"
                 f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
             )
         else:
             logger.warning(
-                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                 f"If your task is similar to the task the model of the checkpoint was trained on, "
                 f"you can already use {model.__class__.__name__} for predictions without further training."
             )
-        if len(error_msgs) > 0:
-            raise RuntimeError(
-                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
-            )
 
         if output_loading_info:
-            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
+            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
 
             return model, loading_info
 
         return model

tests/test_modeling_tf_bert.py
@@ -17,7 +17,7 @@
 import unittest
 
 from transformers import BertConfig, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import require_tf
 
 from .test_configuration_common import ConfigTester
 from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@@ -317,9 +317,14 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
 
-    @slow
     def test_model_from_pretrained(self):
-        # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-        for model_name in ["bert-base-uncased"]:
-            model = TFBertModel.from_pretrained(model_name)
-            self.assertIsNotNone(model)
+        model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")
+        self.assertIsNotNone(model)
+
+    def test_custom_load_tf_weights(self):
+        model, output_loading_info = TFBertForTokenClassification.from_pretrained(
+            "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
+        )
+        self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
+        for layer in output_loading_info["missing_keys"]:
+            self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])
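The missing-keys assertion leans on Keras auto-naming, where duplicated layers get numeric suffixes (e.g. dropout_37), so splitting on "_" recovers the base layer name. A standalone restatement of that check with hypothetical names:

missing_keys = ["dropout_37", "classifier"]  # hypothetical loading-info contents
for layer in missing_keys:
    assert layer.split("_")[0] in ["dropout", "classifier"]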