From c49cd927f7b59fc1308ff8073bde31ddeb15eed2 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 28 Jul 2020 18:29:35 -0400 Subject: [PATCH] [Fix] position_ids tests again (#6100) --- src/transformers/modeling_bert.py | 3 +-- tests/test_modeling_auto.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 757eb7c9c..850cae298 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -568,6 +568,7 @@ class BertPreTrainedModel(PreTrainedModel): config_class = BertConfig load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" + authorized_missing_keys = [r"position_ids"] def _init_weights(self, module): """ Initialize the weights """ @@ -699,8 +700,6 @@ class BertModel(BertPreTrainedModel): """ - authorized_missing_keys = [r"position_ids"] - def __init__(self, config): super().__init__(config) self.config = config diff --git a/tests/test_modeling_auto.py b/tests/test_modeling_auto.py index cbbf857bc..b86506a01 100644 --- a/tests/test_modeling_auto.py +++ b/tests/test_modeling_auto.py @@ -88,9 +88,11 @@ class AutoModelTest(unittest.TestCase): model, loading_info = AutoModelForPreTraining.from_pretrained(model_name, output_loading_info=True) self.assertIsNotNone(model) self.assertIsInstance(model, BertForPreTraining) + # Only one value should not be initialized and in the missing keys. + missing_keys = loading_info.pop("missing_keys") + self.assertListEqual(["cls.predictions.decoder.bias"], missing_keys) for key, value in loading_info.items(): - # Only one value should not be initialized and in the missing keys. - self.assertEqual(len(value), 1 if key == "missing_keys" else 0) + self.assertEqual(len(value), 0) @slow def test_lmhead_model_from_pretrained(self):