diff --git a/test/test_serialization.py b/test/test_serialization.py
index f22331831c3..9890d9b2d9f 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -1040,8 +1040,14 @@ class TestSerialization(TestCase, SerializationMixin):
             self.assertIsNone(torch.load(f, weights_only=False))
             f.seek(0)
             # Safe load should assert
-            with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
+            with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
                 torch.load(f, weights_only=True)
+            try:
+                torch.serialization.add_safe_globals([print])
+                f.seek(0)
+                torch.load(f, weights_only=True)
+            finally:
+                torch.serialization.clear_safe_globals()
 
     @parametrize('weights_only', (False, True))
     def test_serialization_math_bits(self, weights_only):
diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py
index 2ca07d15136..cf1514dcf41 100644
--- a/torch/_weights_only_unpickler.py
+++ b/torch/_weights_only_unpickler.py
@@ -23,6 +23,7 @@
 # weights = torch.load(buf, weights_only = True)
 
 import functools as _functools
+import warnings
 from collections import Counter, OrderedDict
 from pickle import (
     APPEND,
@@ -67,6 +68,16 @@ from struct import unpack
 from sys import maxsize
 from typing import Any, Dict, List
 
+try:
+    # We rely on this module in private cPython which provides dicts of
+    # modules/functions that had their names changed from Python 2 to 3
+    has_compat_pickle = True
+    from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
+except ImportError:
+    # To prevent warning on import torch, we warn in the Unpickler.load below
+    has_compat_pickle = False
+    IMPORT_MAPPING, NAME_MAPPING = dict(), dict()
+
 import torch
 
 _marked_safe_globals_list: List[Any] = []
@@ -97,7 +108,8 @@ def _clear_safe_globals():
 def _get_user_allowed_globals():
     rc: Dict[str, Any] = {}
     for f in _marked_safe_globals_list:
-        rc[f"{f.__module__}.{f.__name__}"] = f
+        module, name = f.__module__, f.__name__
+        rc[f"{module}.{name}"] = f
     return rc
 
 
@@ -170,12 +182,20 @@ class Unpickler:
         self.readline = file.readline
         self.read = file.read
         self.memo: Dict[int, Any] = {}
+        self.proto: int = -1
 
     def load(self):
         """Read a pickled object representation from the open file.
 
        Return the reconstituted object hierarchy specified in the file.
        """
+        if not has_compat_pickle:
+            warnings.warn(
+                "Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
+                "If the default `pickle_protocol` was used at `torch.save` time, any functions or "
+                "classes that are in these maps might not behave correctly if allowlisted via "
+                "`torch.serialization.add_safe_globals()`."
+            )
         self.metastack = []
         self.stack: List[Any] = []
         self.append = self.stack.append
@@ -190,6 +210,13 @@ class Unpickler:
             if key[0] == GLOBAL[0]:
                 module = readline()[:-1].decode("utf-8")
                 name = readline()[:-1].decode("utf-8")
+                # Patch since torch.save default protocol is 2
+                # users will be running this code in python > 3
+                if self.proto == 2 and has_compat_pickle:
+                    if (module, name) in NAME_MAPPING:
+                        module, name = NAME_MAPPING[(module, name)]
+                    elif module in IMPORT_MAPPING:
+                        module = IMPORT_MAPPING[module]
                 full_path = f"{module}.{name}"
                 if full_path in _get_allowed_globals():
                     self.append(_get_allowed_globals()[full_path])
@@ -334,8 +361,14 @@ class Unpickler:
                 self.append(decode_long(data))
             # First and last deserializer ops
             elif key[0] == PROTO[0]:
-                # Read and ignore proto version
-                read(1)[0]
+                self.proto = read(1)[0]
+                if self.proto != 2:
+                    warnings.warn(
+                        f"Detected pickle protocol {self.proto} in the checkpoint, which was "
+                        "not the default pickle protocol used by `torch.load` (2). The weights_only "
+                        "Unpickler might not support all instructions implemented by this protocol, "
+                        "please file an issue for adding support if you encounter this."
+                    )
             elif key[0] == STOP[0]:
                 rc = self.stack.pop()
                 return rc