From c95dc4ebe7b2fcf2981edd27ddf3461afa49ac98 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 7 Feb 2025 11:39:28 +0100 Subject: [PATCH] update --- src/transformers/modeling_utils.py | 8 ++++---- tests/test_modeling_common.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 81902d02b..ed17c0010 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1476,12 +1476,12 @@ def _find_missing_and_unexpected_keys( # Model-specific exceptions for missing and unexpected keys (e.g. if the modeling change over time, or any other reason...) if cls._keys_to_ignore_on_load_missing is not None: - for pat in cls._keys_to_ignore_on_load_missing: - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + for pattern in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pattern, k) is None] if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + for pattern in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None] return missing_keys, unexpected_keys diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2e2db5b11..c8624124e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2358,7 +2358,9 @@ class ModelTesterMixin: if model_reloaded._keys_to_ignore_on_load_missing is None: expected_missing = set() else: - expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing) + expected_missing = set() + for pattern in model_reloaded._keys_to_ignore_on_load_missing: + expected_missing.update({k for k in model_reloaded.state_dict().keys() if re.search(pattern, k) is not None}) self.assertEqual( missed_missing, expected_missing,