From a979318ef7fabac7e0d7a2101e0e70af75fca7bd Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Fri, 1 Nov 2024 11:57:50 -0700
Subject: [PATCH] Add section to serialization note re weights_only (#139433)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139433
Approved by: https://github.com/malfet
ghstack dependencies: #138936, #139221
---
 docs/source/notes/serialization.rst | 85 +++++++++++++++++++++++++++++
 torch/serialization.py              |  1 +
 2 files changed, 86 insertions(+)

diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst
index 525bf7dba4c..77a4ea5d042 100644
--- a/docs/source/notes/serialization.rst
+++ b/docs/source/notes/serialization.rst
@@ -176,6 +176,7 @@ can use this pattern:
     >>> new_m.load_state_dict(m_state_dict)
 
+
 .. _serialized-file-format:
 
 Serialized file format for ``torch.save``
@@ -214,6 +215,90 @@ is 64-byte aligned.
 such, their storages are not serialized. In these cases ``data/`` might not exist in the checkpoint.
 
+.. _weights-only:
+
+``torch.load`` with ``weights_only=True``
+-----------------------------------------
+
+Starting in version 2.6, ``torch.load`` will use ``weights_only=True`` if the ``pickle_module``
+argument is not passed.
+
+As discussed in the documentation for :func:`torch.load`, ``weights_only=True`` restricts
+the unpickler used in ``torch.load`` to only executing functions/building classes required for
+``state_dicts`` of plain ``torch.Tensors`` as well as some other primitive types. Further,
+unlike the default ``Unpickler`` provided by the ``pickle`` module, the ``weights_only`` Unpickler
+is not allowed to dynamically import anything during unpickling.
+
+As mentioned above, saving a module's ``state_dict`` is a best practice when using ``torch.save``. If loading an old
+checkpoint that contains an ``nn.Module``, we recommend ``weights_only=False``. When loading a checkpoint that contains
+tensor subclasses, there will likely be functions/classes that need to be allowlisted; see below for further details.
+
+If the ``weights_only`` Unpickler encounters a function or class that is not allowlisted
+by default within the pickle file, you should see an actionable error like the following:
+
+.. code::
+
+    _pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
+    to do so you have two options, do those steps only if you trust the source of the checkpoint.
+        1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
+            but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
+        2. Alternatively, to load with `weights_only=True` please check the recommended
+           steps in the following error message.
+           WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
+           default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
+           `torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
+           if you trust this class/function.
+
+Please follow the steps in the error message and allowlist the functions or classes only if you trust them.
+
+To get all GLOBALs (functions/classes) in the checkpoint that are not yet allowlisted, you can use
+:func:`torch.serialization.get_unsafe_globals_in_checkpoint`, which returns a list of strings of the form
+``{__module__}.{__name__}``. If you trust these functions/classes, you can import them and allowlist them per
+the error message either via :func:`torch.serialization.add_safe_globals` or the context manager
+:class:`torch.serialization.safe_globals`.
+
+To access the list of user-allowlisted functions/classes, use :func:`torch.serialization.get_safe_globals`;
+to clear the current list, see :func:`torch.serialization.clear_safe_globals`.
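Reader's note on the allowlisting paragraphs above: the idea of an unpickler that only resolves explicitly allowlisted globals can be sketched with the standard library alone. The block below is an analogy under stated assumptions, not PyTorch's implementation; ``AllowlistUnpickler``, ``SAFE_GLOBALS`` and ``restricted_loads`` are hypothetical names introduced only for illustration.

```python
import io
import pickle
from collections import OrderedDict

# Hypothetical allowlist keyed by (module, name); this is NOT PyTorch's real list.
SAFE_GLOBALS = {("collections", "OrderedDict"): OrderedDict}


class AllowlistUnpickler(pickle.Unpickler):
    """Resolve only allowlisted globals and never import anything dynamically."""

    def find_class(self, module, name):
        try:
            return SAFE_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name} is not allowlisted"
            ) from None


def restricted_loads(data: bytes):
    """Unpickle bytes with the allowlisting unpickler (illustrative helper)."""
    return AllowlistUnpickler(io.BytesIO(data)).load()


# An allowlisted container holding primitive values loads normally.
state = restricted_loads(pickle.dumps(OrderedDict(weight=[1.0, 2.0])))

# A global that is not on the allowlist is rejected instead of being imported.
try:
    restricted_loads(pickle.dumps(complex))
    rejected = False
except pickle.UnpicklingError:
    rejected = True
```

Adding an entry to ``SAFE_GLOBALS`` in this sketch plays the role that ``torch.serialization.add_safe_globals`` plays in the section above: the caller decides which classes or functions are trusted, and everything else fails loudly.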
+
+Troubleshooting ``weights_only``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Getting unsafe globals
+""""""""""""""""""""""
+
+A caveat is that :func:`torch.serialization.get_unsafe_globals_in_checkpoint` analyzes the checkpoint statically,
+so types that are built dynamically during the unpickling process will not be reported by
+:func:`torch.serialization.get_unsafe_globals_in_checkpoint`. One such example is ``dtypes`` in numpy. In
+``numpy < 1.25``, after allowlisting all the functions/classes reported by
+:func:`torch.serialization.get_unsafe_globals_in_checkpoint`, you might see an error like:
+
+.. code::
+
+    WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
+    but got <class 'numpy.dtype[float32]'>
+
+This can be allowlisted via ``{add_}safe_globals([type(np.dtype(np.float32))])``.
+
+In ``numpy >= 1.25`` you would see:
+
+.. code::
+
+    WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
+    but got <class 'numpy.dtypes.Float32DType'>
+
+This can be allowlisted via ``{add_}safe_globals([np.dtypes.Float32DType])``.
+
+Environment Variables
+"""""""""""""""""""""
+
+There are two environment variables that will influence the behavior of ``torch.load``. These can be helpful
+if one does not have access to the ``torch.load`` callsites.
+
+* ``TORCH_FORCE_WEIGHTS_ONLY_LOAD=1`` will override all ``torch.load`` callsites to use ``weights_only=True``.
+* ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1`` will make ``torch.load`` callsites use ``weights_only=False`` **only**
+  if ``weights_only`` was not passed as an argument.
+
+
 .. _serializing-python-modules:
 
 Serializing torch.nn.Modules and loading them in C++
diff --git a/torch/serialization.py b/torch/serialization.py
index c270e3d1229..857e70c23a1 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -1248,6 +1248,7 @@ def load(
         weights_only: Indicates whether unpickler should be restricted to
             loading only tensors, primitive types, dictionaries
             and any types added via :func:`torch.serialization.add_safe_globals`.
+            See :ref:`weights-only` for more details.
         mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
             Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
             are moved to the location that they were tagged with when saving, or specified by ``map_location``. This