diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 09fd9e858b8..225486cdeda 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -394,3 +394,6 @@ The following utility functions are related to serialization: .. autofunction:: set_default_load_endianness .. autofunction:: get_default_mmap_options .. autofunction:: set_default_mmap_options +.. autofunction:: add_safe_globals +.. autofunction:: clear_safe_globals +.. autofunction:: get_safe_globals diff --git a/test/dynamo_expected_failures/TestSubclassSerialization.test_allowlist_for_weights_only b/test/dynamo_expected_failures/TestSubclassSerialization.test_allowlist_for_weights_only new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/test_serialization.py b/test/test_serialization.py index 5c6b78b4456..1be1b06ab78 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -15,8 +15,10 @@ import pickle import shutil import pathlib import platform +from collections import OrderedDict from copy import deepcopy from itertools import product +from types import ModuleType from torch._utils_internal import get_file_path_2 from torch._utils import _rebuild_tensor @@ -27,9 +29,10 @@ from torch.serialization import check_module_version_greater_or_equal, get_defau from torch.testing._internal.common_utils import ( IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName, TestCase, IS_FBCODE, IS_WINDOWS, TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName, - parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest) + parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest, skipIfTorchDynamo) from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_dtype import all_types_and_complex_and +from torch.testing._internal.two_tensor import TwoTensor # noqa: F401 if not IS_WINDOWS: 
from mmap import MAP_SHARED, MAP_PRIVATE @@ -1038,7 +1041,7 @@ class TestSerialization(TestCase, SerializationMixin): self.assertIsNone(torch.load(f, weights_only=False)) f.seek(0) # Safe load should assert - with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported class"): + with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"): torch.load(f, weights_only=True) @parametrize('weights_only', (False, True)) @@ -4108,6 +4111,23 @@ class TestGetStateSubclass(torch.Tensor): class TestEmptySubclass(torch.Tensor): ... +# ONLY use SubclassSpoof subclasses for the subclass spoof tests since we modify them +# Cannot define locally in test or pickle will fail. +class TestEmptySubclassSpoof(TestEmptySubclass): + ... + +class TestWrapperSubclassSpoof(TestWrapperSubclass): + ... + +class RebuildFromTypeV2Spoof(torch.Tensor): + def __new__(cls, elem, naughty, **kwargs): + if naughty: + raise RuntimeError("naughty") + return super().__new__(cls, elem) + + def __reduce_ex__(self, protocol): + return (torch._tensor._rebuild_from_type_v2, (RebuildFromTypeV2Spoof, torch.Tensor, (True,), {})) + class TestSubclassSerialization(TestCase): def test_tensor_subclass_wrapper_serialization(self): @@ -4187,6 +4207,203 @@ class TestSubclassSerialization(TestCase): f.seek(0) tensor2 = torch.load(f) + def _create_bad_func(self, name): + def bad_func(self, *args, **kwargs): + raise RuntimeError(f"running {name}") + return bad_func + + @parametrize("wrapper", (True, False)) + def test_tensor_subclass_method_spoofing(self, wrapper): + ''' + This tests seeks to do the following: + - determine which methods of a tensor subclass might be called during unpickling (weights_only=False) + we consider these methods "risky" for weights_only + - ensure that we ban overriding this group of methods on a tensor subclass by default (weights_only=True) + - ensure that tensor subclass that doesn't override any of these can be unpickled (weights_only=True) + + We 
achieve this by overriding all methods of a tensor subclass to raise a RuntimeError + when called. We then try to unpickle a tensor subclass with weights_only=False and ensure that + only the RuntimeErrors that we expect are thrown. + + We then load with weights_only and ensure that weights_only will fail unless all the risky methods + are not overridden by resetting the risky methods to the non-overridden version in a loop and calling load. + The final weights_only load call, made when all the risky methods are no longer overridden, should succeed. + ''' + subclass = TestWrapperSubclassSpoof if wrapper else TestEmptySubclassSpoof + t = subclass(torch.randn(2, 3)) + # To trigger setattr for the non-wrapper case + if not wrapper: + t.foo = 'bar' + inp = {'weight': t} + + with TemporaryFileName() as f: + torch.save(inp, f) + loaded = torch.load(f, weights_only=True) + self.assertEqual(loaded['weight'], inp['weight']) + + restore_methods = dict() + methods = [func for func in dir(subclass) if callable(getattr(subclass, func))] + for method in methods: + if method != "__class__": + restore_methods[method] = getattr(subclass, method) + setattr(subclass, method, self._create_bad_func(method)) + # These additional methods might be called during getattr or setattr + # but are not in methods above (not defined on tensor base class) + subclass.__get__ = self._create_bad_func("__get__") + subclass.__set__ = self._create_bad_func("__set__") + subclass.__getattr__ = self._create_bad_func("__getattr__") + restore_methods["__get__"] = None + restore_methods["__getattr__"] = None + restore_methods["__set__"] = None + + try: + # Check that weights_only=False load raises the RuntimeErrors we expect + with self.assertRaisesRegex(RuntimeError, "running __getattribute__"): + torch.load(f, weights_only=False) + subclass.__getattribute__ = restore_methods['__getattribute__'] + with self.assertRaisesRegex(RuntimeError, "running __setstate__"): + torch.load(f, weights_only=False) + subclass.__setstate__ =
restore_methods['__setstate__'] + with self.assertRaisesRegex(RuntimeError, "running __setattr__"): + torch.load(f, weights_only=False) + subclass.__setattr__ = restore_methods['__setattr__'] + # should finally work + torch.load(f, weights_only=False) + + # Check that weights_only=True catches that risky methods are overriden + subclass.__setstate__ = self._create_bad_func("__setstate__") + subclass.__getattribute__ = self._create_bad_func("__getattribute__") + subclass.__setattr__ = self._create_bad_func("__setattr__") + with self.assertRaisesRegex(pickle.UnpicklingError, + "methods: __getattribute__=True __getattr__=True __get__=True " + "__setattr__=True __set__=True __setstate__=True"): + torch.load(f, weights_only=True) + risky_methods = ['__get__', '__set__', '__getattr__', '__setattr__', '__getattribute__', '__setstate__'] + for i, meth in enumerate(risky_methods): + setattr(subclass, meth, restore_methods[meth]) + if i != len(risky_methods) - 1: + # When the given methods are not all back to default, load should still throw + # but reflect which methods are no longer overriden + with self.assertRaisesRegex(pickle.UnpicklingError, f"{meth}=False"): + torch.load(f, weights_only=True) + else: + # When the given methods are all back to default, weights_only load should finally work + loaded = torch.load(f, weights_only=True) + finally: + for method, func in restore_methods.items(): + setattr(subclass, method, func) + a = subclass(torch.randn(2, 3)) + + @skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined") + def test_safe_globals_for_weights_only(self): + ''' + Tests import semantic for tensor subclass and the {add/get/clear}_safe_globals APIs + ''' + # Needed to prevent UnboundLocalError: local variable 'TwoTensor' referenced before assignment + global TwoTensor + t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3)) + p = torch.nn.Parameter(t) + sd = OrderedDict([('t', t), ('p', p)]) + + with tempfile.NamedTemporaryFile() as f: + torch.save(sd, f) + # 
unimport TwoTensor + try: + del sys.modules['torch.testing._internal.two_tensor'] + + # Loading tensor subclass with weights_only=True should fail + # if tensor subclass has not been imported + with self.assertRaisesRegex(pickle.UnpicklingError, + "expect `torch.testing._internal.two_tensor` to be present in `sys.modules`"): + f.seek(0) + sd = torch.load(f, weights_only=True) + + # Loading tensor subclass with weights_only=True should work + # if target methods are not overriden and user has imported the subclass + from torch.testing._internal.two_tensor import TwoTensor + f.seek(0) + sd = torch.load(f, weights_only=True) + self.assertEqual(sd['t'], t) + self.assertEqual(sd['p'], p) + + # Loading tensor subclass with weights_only=True should fail + # if __setstate__ is overriden + f.seek(0) + restore_setstate = TwoTensor.__setstate__ + try: + TwoTensor.__setstate__ = lambda self, state: self.__dict__.update(state) + with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): + torch.load(f, weights_only=True) + + # Loading tensor subclass with overriden __setstate__ with weights_only=True should work + # if the class is marked safe + f.seek(0) + torch.serialization.add_safe_globals([TwoTensor]) + self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor]) + sd = torch.load(f, weights_only=True) + self.assertEqual(sd['t'], t) + self.assertEqual(sd['p'], p) + + # Should fail again when safe globals are cleared + torch.serialization.clear_safe_globals() + f.seek(0) + with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): + torch.load(f, weights_only=True) + finally: + TwoTensor.__setstate__ = restore_setstate + finally: + from torch.testing._internal.two_tensor import TwoTensor + + + def test_tensor_subclass_parent_module_method_spoofing(self): + ''' + Tests that weights_only load does not call any methods of the parent module + that contains the tensor subclass. 
+ + We achieve this by overriding all methods of a module we add to sys.modules to raise a RuntimeError + when called. We then try to unpickle a tensor subclass with weights_only=True and ensure that + no RuntimeErrors are thrown. + ''' + # Simulates user doing `import spoof_mod` where `spoof_mod` contains `TestEmptySubclass` + class SpoofModule(ModuleType): + pass + + spoof_mod = SpoofModule('bla') + spoof_mod.TestEmptySubclass = TestEmptySubclass + inp = {'weight': TestEmptySubclass(torch.randn(2, 3))} + TestEmptySubclass.__module__ = 'spoof_mod' + sys.modules['spoof_mod'] = spoof_mod + + try: + with TemporaryFileName() as f: + torch.save(inp, f) + torch.load(f, weights_only=True) + restore_methods = dict() + methods = [func for func in dir(SpoofModule) if callable(getattr(SpoofModule, func))] + for method in methods: + if method != "__class__": + restore_methods[method] = getattr(SpoofModule, method) + setattr(SpoofModule, method, self._create_bad_func(method)) + SpoofModule.__get__ = self._create_bad_func("__get__") + SpoofModule.__getattr__ = self._create_bad_func("__getattr__") + loaded = torch.load(f, weights_only=True) + self.assertEqual(loaded['weight'], inp['weight']) + finally: + TestEmptySubclass.__module__ = __name__ + del sys.modules['spoof_mod'] + + def test_rebuild_from_type_v2_spoof(self): + t = RebuildFromTypeV2Spoof(torch.randn(2, 3), False) + inp = {'weight': t} + + with TemporaryFileName() as f: + torch.save(inp, f) + # subclass will be pushed onto unpickler's stack as a string + # and only gets converted to the type if it is argument 1 to _rebuild_from_type_v2 + with self.assertRaisesRegex(TypeError, "'str' object is not callable"): + loaded = torch.load(f, weights_only=True) + + instantiate_device_type_tests(TestBothSerialization, globals()) instantiate_parametrized_tests(TestSubclassSerialization) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index ac70396c468..0599da2117f 100644 --- a/torch/_C/__init__.pyi.in +++ 
b/torch/_C/__init__.pyi.in @@ -1196,6 +1196,7 @@ def _has_storage(x: Tensor) -> _bool: ... def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ... def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, str], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ... +def _check_tp_alloc_is_default(cls: Type) -> _bool: ... # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 44dd8223862..6c9f3b61ae8 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -9,6 +9,10 @@ # - `torch.nn.Parameter` # - `collections.Counter` # - `collections.OrderedDict` +# Additionally, users can use an allowlist for adding classes they have deemed as safe using +# `_add_safe_globals()` (`torch.serialization.add_safe_globals`) +# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`) +# `_get_safe_globals()` (`torch.serialization.get_safe_globals`) # Based of https://github.com/python/cpython/blob/main/Lib/pickle.py # Expected to be useful for loading PyTorch model weights @@ -19,6 +23,7 @@ import functools as _functools from collections import Counter, OrderedDict +from inspect import getattr_static from pickle import ( APPEND, APPENDS, @@ -59,11 +64,57 @@ from pickle import ( UnpicklingError, ) from struct import unpack -from sys import maxsize -from typing import Any, Dict, List +from sys import maxsize, modules +from typing import Any, Dict, List, Type import torch +_marked_safe_globals_list: List[Any] = [] + + +def _add_safe_globals(safe_globals: List[Any]): + global _marked_safe_globals_list + _marked_safe_globals_list += safe_globals + + +def _get_safe_globals() -> List[Any]: + global 
_marked_safe_globals_list + return _marked_safe_globals_list + + +def _clear_safe_globals(): + global _marked_safe_globals_list + _marked_safe_globals_list = [] + + +# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals +# For example if user had a script like +# torch.load(file_a) +# torch.serialization._add_safe_globals([torch.foo]) +# torch.load(file_b) +# the dynamic additions to safe_globals would not be picked up by +# _get_allowed_globals due to the lru_cache +def _get_user_allowed_globals(): + rc: Dict[str, Any] = {} + for f in _marked_safe_globals_list: + rc[f"{f.__module__}.{f.__name__}"] = f + return rc + + +def _tensor_rebuild_functions(): + return { + torch._utils._rebuild_parameter, + torch._utils._rebuild_parameter_with_state, + torch._utils._rebuild_qtensor, + torch._utils._rebuild_tensor, + torch._utils._rebuild_tensor_v2, + torch._utils._rebuild_tensor_v3, + torch._utils._rebuild_sparse_tensor, + torch._utils._rebuild_meta_tensor_no_storage, + torch._utils._rebuild_nested_tensor, + torch._utils._rebuild_wrapper_subclass, + } + # Unpickling machinery @_functools.lru_cache(maxsize=1) @@ -75,6 +126,7 @@ def _get_allowed_globals(): "torch.serialization._get_layout": torch.serialization._get_layout, "torch.Size": torch.Size, "torch.Tensor": torch.Tensor, + "torch.device": torch.device, } # dtype for t in torch.storage._dtype_to_storage_type_map().keys(): @@ -103,17 +155,7 @@ def _get_allowed_globals(): ]: rc[str(qt)] = qt # Rebuild functions - for f in [ - torch._utils._rebuild_parameter, - torch._utils._rebuild_parameter_with_state, - torch._utils._rebuild_qtensor, - torch._utils._rebuild_tensor, - torch._utils._rebuild_tensor_v2, - torch._utils._rebuild_tensor_v3, - torch._utils._rebuild_sparse_tensor, - torch._utils._rebuild_meta_tensor_no_storage, - torch._utils._rebuild_nested_tensor, - ]: + for f in _tensor_rebuild_functions(): rc[f"torch._utils.{f.__name__}"] = f # Handles Tensor Subclasses, Tensor's with 
attributes. @@ -128,6 +170,11 @@ class Unpickler: self.readline = file.readline self.read = file.read self.memo: Dict[int, Any] = {} + # tensor subclass types found from GLOBAL instructions that have passed the criteria + # to be allowed as the second argument to `torch._tensor._rebuild_from_type_v2` + # This enables rebuilding of tensor subclasses defined outside the `torch` package. + # See [Note: Criteria for allowing out-of-core tensor subclasses] for details on the criteria. + self.tensor_subclasses_found: Dict[str, Type] = {} def load(self): """Read a pickled object representation from the open file. @@ -151,8 +198,124 @@ class Unpickler: full_path = f"{module}.{name}" if full_path in _get_allowed_globals(): self.append(_get_allowed_globals()[full_path]) + elif full_path in _get_user_allowed_globals(): + self.append(_get_user_allowed_globals()[full_path]) else: - raise RuntimeError(f"Unsupported class {full_path}") + # The logic in this branch handles user-defined tensor subclasses. + # We automatically allow these and raise an error for anything that is not provably safe. + # [Note: Criteria for allowing out-of-core tensor subclasses] + # GLOBAL '{module}.{name}' instructions will get the class, add the type to + # `self.tensor_subclasses_found` keyed by the string `full_path`, and push that string + # (not the actual type) onto the unpickler's stack if they satisfy the following conditions: + # (1) The module that defines them is in `sys.modules` + # (we will use getattr_static to access it to ensure no code execution) + # (2) They inherit from `torch.Tensor` + # (3) The class is not overriding any of the `torch.Tensor` methods listed here: + # `__getattr__`, `__get__`, `__getattribute__`, `__setattr__`, `__setstate__`, `__set__`, + # and `tp_alloc` + # The methods that we ban overriding were selected in a test-driven manner + # by overriding every callable method on a tensor subclass and determining + # which might get called during unpickling.
+ # When executing REDUCE, the string will be appropriately converted back to the type only + # for `torch._tensor._rebuild_from_type_v2` as other use of the class could use methods + # we didn't audit. + if module == "__builtin__": + raise RuntimeError( + f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " + "Please use `torch.serialization.add_safe_globals` to allowlist this global " + "if you trust this class/function." + ) + elif module not in modules: + # TODO: add a link here to a doc that explains to users what we mean by trust + raise RuntimeError( + f"Found GLOBAL `{full_path}` instruction in the pickle file but `{full_path}` was " + f"not in the pre-defined list of allowed globals that are considered safe by the " + "weights_only unpickler for rebuilding state_dicts. This is the expected behavior if " + f"`{full_path}` is a class or function that is not in the list of allowed globals " + f"If `{full_path}` is NOT a tensor subclass, you might consider" + "`torch.serialization.add_safe_globals` if it is appropriate. However, if it is a " + "user-defined tensor subclass not defined in the `torch` package, this error might arise " + f"as we expect `{module}` to be present in `sys.modules` (i.e. it " + "must be imported in the current environment), but this was not the case. " + f"If you intend to unpickle a tensor subclass `{full_path}` please import `{name}` from " + f"`{module}`. Note that having this imported will *only* allow the type `{full_path}` to " + "be passed as the second argument to `torch._tensor._rebuild_from_type_v2`, which should " + "enable the tensor subclass to be unpickled without any arbitrary code execution as long " + # If the user imports and these are overridden the next error will prompt them to use + # torch.serialization.add_safe_globals. + "a sa pre-defined list of methods called when unpickling are not overridden. 
In " + "particular, the methods are `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, " + "`__set__`, as well as the implementation of `tp_alloc`." + ) + else: + try: + class_type = getattr_static(modules[module], name) + except AttributeError as e: + raise AttributeError( + "For safety during weights_only loading, we use inspect.getattr_state to " + f"get {name} from {module}, if {module} implements the descriptor protocol, " + "__getattr__ or __getattribute__ these will not be called." + ) from e + # None of the objects here contain any data from the pickle so this is safe + if isinstance(class_type, type) and issubclass( + class_type, torch.Tensor + ): + # getattr is called by the getattr call in `_rebuild_from_type_v2` + custom_get_attribute = ( + class_type.__getattribute__ + is not torch.Tensor.__getattribute__ + ) + custom_get = ( + getattr_static(class_type, "__get__", None) is not None + ) + custom_get_attr = ( + getattr_static(class_type, "__getattr__", None) + is not None + ) + # Tensor.__setstate__ might be called in `_rebuild_from_type_v2` + custom_set_state = ( + class_type.__setstate__ is not torch.Tensor.__setstate__ + ) + # setattr is called in `torch._utils._set_obj_state` + custom_set_attr = ( + class_type.__setattr__ is not object.__setattr__ + ) + custom_set = ( + getattr_static(class_type, "__set__", None) is not None + ) + # tp_alloc is called by `Tensor._rebuild_wrapper_subclass` and `Tensor.as_subclass` + has_custom_tp_alloc = ( + not torch._C._check_tp_alloc_is_default(class_type) + ) + custom_methods = { + "__getattribute__": custom_get_attribute, + "__getattr__": custom_get_attr, + "__get__": custom_get, + "__setattr__": custom_set_attr, + "__set__": custom_set, + "__setstate__": custom_set_state, + "tp_alloc": has_custom_tp_alloc, + } + if any(custom_methods.values()): + error = "" + for k, v in custom_methods.items(): + error += f" {k}={v}" + raise RuntimeError( + f"Trying to unpickle tensor subclass `{full_path}` that has 
defined a custom " + f"version for one of these methods:{error}. Please check whether you trust these " + "methods and allowlist the subclass with `torch.serialization.add_safe_globals` if so." + ) + # push the string full_path onto the stack (in REDUCE, there is special logic to + # access this from tensor_subclasses_found for _rebuild_from_type_v2) + self.tensor_subclasses_found[full_path] = class_type + self.append(full_path) + else: + raise RuntimeError( + f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " + "Please use `torch.serialization.add_safe_globals` to allowlist this global " + "if you trust this class/function." + ) + elif key[0] == NEWOBJ[0]: args = self.stack.pop() cls = self.stack.pop() @@ -162,10 +325,33 @@ class Unpickler: elif key[0] == REDUCE[0]: args = self.stack.pop() func = self.stack[-1] - if func not in _get_allowed_globals().values(): + if ( + func not in _get_allowed_globals().values() + and func not in _get_user_allowed_globals().values() + ): raise RuntimeError( f"Trying to call reduce for unrecognized function {func}" ) + # Special handling for tensor subclass type found in GLOBAL that is pushed + # onto stack as str to prevent it from being used anywhere except the + # second arg of _rebuild_from_type_v2 and within argument tuple for _rebuild_wrapper_subclass + # _rebuild_from_type_v2 is called with args (func, type, func_args, state) + # where both type and, when func is rebuild_wrapper_subclass, func_args[0] could be the subclass type + # Since we pushed these subclass types onto the stack as strings, convert them to the actual + # type here.
+ if func is torch._tensor._rebuild_from_type_v2 and type(args[1]) is str: + args_after = args[2:] + if ( + args[0] is torch._utils._rebuild_wrapper_subclass + and type(args[2][0]) is str + ): + new_arg_tuple = ( + self.tensor_subclasses_found[args[2][0]], + ) + args[2][1:] + args_after = (new_arg_tuple,) + args[3:] + args = ( + args[:1] + (self.tensor_subclasses_found[args[1]],) + args_after + ) self.stack[-1] = func(*args) elif key[0] == BUILD[0]: state = self.stack.pop() diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 3be764220e0..9ff9131435f 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -422,6 +422,19 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { END_HANDLE_TH_ERRORS } +PyObject* THPModule_check_tp_alloc_is_default( + PyObject* _unused, + PyObject* cls) { + HANDLE_TH_ERRORS + TORCH_CHECK_TYPE( + PyType_Check(cls), + "cls must be a type (got ", + Py_TYPE(cls)->tp_name, + ")"); + return PyBool_FromLong(Py_TYPE(cls)->tp_alloc == PyType_GenericAlloc); + END_HANDLE_TH_ERRORS +} + PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) { // adds a __doc__ string to a function, similar to numpy's arr_add_docstring static std::vector all_docs; @@ -1268,6 +1281,10 @@ static PyMethodDef TorchMethods[] = { // NOLINT {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr}, {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr}, {"_swap_tensor_impl", THPModule_swap_tensor_impl, METH_VARARGS, nullptr}, + {"_check_tp_alloc_is_default", + THPModule_check_tp_alloc_is_default, + METH_O, + nullptr}, {"_init_names", THPModule_initNames, METH_O, nullptr}, {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr}, {"_set_default_tensor_type", diff --git a/torch/serialization.py b/torch/serialization.py index 64a1e6e0ce0..a7703b9964d 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -59,6 +59,9 @@ __all__ = [ 'LoadEndianness', 'get_default_load_endianness', 
'set_default_load_endianness', + 'clear_safe_globals', + 'get_safe_globals', + 'add_safe_globals', ] @@ -148,6 +151,27 @@ def set_default_mmap_options(flags: int): f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}") _default_mmap_options = flags +def clear_safe_globals() -> None: + ''' + Clears the list of globals that are safe for ``weights_only`` load. + ''' + _weights_only_unpickler._clear_safe_globals() + +def get_safe_globals() -> List[Any]: + ''' + Returns the list of user-added globals that are safe for ``weights_only`` load. + ''' + return _weights_only_unpickler._get_safe_globals() + +def add_safe_globals(safe_globals: List[Any]) -> None: + ''' + Marks the given globals as safe for ``weights_only`` load. + + Args: + safe_globals (List[Any]): list of globals to mark as safe + ''' + _weights_only_unpickler._add_safe_globals(safe_globals) + def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -952,7 +976,9 @@ def load( UNSAFE_MESSAGE = ( "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`" " will likely succeed, but it can result in arbitrary code execution." - "Do it only if you get the file from a trusted source. WeightsUnpickler error: " + " Do it only if you get the file from a trusted source. Alternatively, to load" + " with `weights_only` please check the recommended steps in the following error message." + " WeightsUnpickler error: " ) # Add ability to force safe only weight loads via environment variable if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: