Add section to serialization note re weights_only (#139433)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139433
Approved by: https://github.com/malfet
ghstack dependencies: #138936, #139221
Mikayla Gawarecki 2024-11-01 11:57:50 -07:00 committed by PyTorch MergeBot
parent a1f854f270
commit a979318ef7
2 changed files with 86 additions and 0 deletions

@ -176,6 +176,7 @@ can use this pattern:
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
.. _serialized-file-format:
Serialized file format for ``torch.save``
@ -214,6 +215,90 @@ is 64-byte aligned.
such, their storages are not serialized. In these cases ``data/`` might not exist
in the checkpoint.
.. _weights-only:
``torch.load`` with ``weights_only=True``
-----------------------------------------
Starting in version 2.6, ``torch.load`` will use ``weights_only=True`` if the ``pickle_module``
argument is not passed.
As discussed in the documentation for :func:`torch.load`, ``weights_only=True`` restricts
the unpickler used in ``torch.load`` to only executing functions/building classes required for
``state_dicts`` of plain ``torch.Tensors`` as well as some other primitive types. Further,
unlike the default ``Unpickler`` provided by the ``pickle`` module, the ``weights_only`` Unpickler
is not allowed to dynamically import anything during unpickling.
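To make the restriction concrete, here is a minimal sketch of an allowlisting unpickler built on the standard library's ``pickle.Unpickler.find_class`` hook. It only illustrates the idea and is not PyTorch's actual ``weights_only`` Unpickler; the names ``SAFE_GLOBALS``, ``AllowlistUnpickler`` and ``restricted_loads`` are hypothetical and exist only for this example.

```python
import collections
import io
import pickle

# Hypothetical allowlist for this sketch, keyed by (module, name).
SAFE_GLOBALS = {
    ("collections", "OrderedDict"): collections.OrderedDict,
}


class AllowlistUnpickler(pickle.Unpickler):
    """Resolve only allowlisted globals; never import anything else."""

    def find_class(self, module, name):
        try:
            return SAFE_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name}"
            ) from None


def restricted_loads(data: bytes):
    return AllowlistUnpickler(io.BytesIO(data)).load()


# An allowlisted class round-trips normally.
ok = restricted_loads(pickle.dumps(collections.OrderedDict(a=1)))

# A class that is not allowlisted is rejected instead of being imported.
try:
    restricted_loads(pickle.dumps(collections.Counter("aab")))
    blocked = False
    message = ""
except pickle.UnpicklingError as err:
    blocked = True
    message = str(err)
```

Because ``find_class`` is the only place where the unpickler maps a ``module.name`` reference to a Python object, refusing unknown names there prevents both arbitrary callables and implicit imports.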
As mentioned above, saving a module's ``state_dict`` is a best practice when using ``torch.save``. If loading an old
checkpoint that contains an ``nn.Module``, we recommend ``weights_only=False``. When loading a checkpoint that contains
tensor subclasses, there will likely be functions/classes that need to be allowlisted; see below for further details.
If the ``weights_only`` Unpickler encounters a function or class in the pickle file that is not allowlisted
by default, you should see an actionable error like the following:
.. code::
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
2. Alternatively, to load with `weights_only=True` please check the recommended
steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
`torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
if you trust this class/function.
Please follow the steps in the error message and allowlist the functions or classes only if you trust them.
To get all GLOBALs (functions/classes) in the checkpoint that are not yet allowlisted you can use
:func:`torch.serialization.get_unsafe_globals_in_checkpoint` which will return a list of strings of the form
``{__module__}.{__name__}``. If you trust these functions/classes, you can import them and allowlist them per
the error message either via :func:`torch.serialization.add_safe_globals` or the context manager
:class:`torch.serialization.safe_globals`.
To access the list of user-allowlisted functions/classes you can use :func:`torch.serialization.get_safe_globals` and
to clear the current list see :func:`torch.serialization.clear_safe_globals`.
Troubleshooting ``weights_only``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Getting unsafe globals
""""""""""""""""""""""
A caveat is that :func:`torch.serialization.get_unsafe_globals_in_checkpoint` analyzes the checkpoint statically,
so types that are built dynamically during the unpickling process will not be reported by it.
One such example is ``dtypes`` in numpy. In
``numpy < 1.25``, after allowlisting all the functions/classes reported by
:func:`torch.serialization.get_unsafe_globals_in_checkpoint`, you might see an error like
.. code::
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>
This can be allowlisted via ``{add_}safe_globals([type(np.dtype(np.float32))])``.
In ``numpy >= 1.25`` you would see
.. code::
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>
This can be allowlisted via ``{add_}safe_globals([np.dtypes.Float32DType])``.
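To illustrate what "statically" means in the caveat at the start of this subsection, the sketch below scans a pickle stream for ``GLOBAL`` opcodes with the standard library's ``pickletools`` and never executes it. It is an assumption-laden illustration, not the implementation of ``get_unsafe_globals_in_checkpoint``: the helper name ``referenced_globals`` is hypothetical, and the payload is pickled with protocol 2 so that globals appear as ``GLOBAL`` opcodes (newer protocols use ``STACK_GLOBAL``, which this sketch does not handle).

```python
import pickle
import pickletools
from collections import OrderedDict
from fractions import Fraction


def referenced_globals(data: bytes) -> set:
    """Collect 'module.name' for every GLOBAL opcode without unpickling."""
    found = set()
    for opcode, arg, _pos in pickletools.genops(data):
        if opcode.name == "GLOBAL":
            # pickletools reports the argument as "module name".
            module, name = arg.split(" ", 1)
            found.add(f"{module}.{name}")
    return found


payload = pickle.dumps(OrderedDict(ratio=Fraction(1, 3)), protocol=2)
names = referenced_globals(payload)
print(sorted(names))
```

A scan like this only sees the names written into the stream. It cannot know the type of the object a referenced callable will return at load time, which is why a dynamically built type such as a numpy dtype class can still surface as an error after everything reported statically has been allowlisted.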
Environment Variables
"""""""""""""""""""""
There are two environment variables that will influence the behavior of ``torch.load``. These can be helpful
if one does not have access to the ``torch.load`` callsites.
* ``TORCH_FORCE_WEIGHTS_ONLY_LOAD=1`` will override all ``torch.load`` callsites to use ``weights_only=True``.
* ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1`` will make ``torch.load`` callsites use ``weights_only=False`` **only**
if ``weights_only`` was not passed as an argument.
.. _serializing-python-modules:
Serializing torch.nn.Modules and loading them in C++

@ -1248,6 +1248,7 @@ def load(
weights_only: Indicates whether unpickler should be restricted to
loading only tensors, primitive types, dictionaries
and any types added via :func:`torch.serialization.add_safe_globals`.
See :ref:`weights-only` for more details.
mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
are moved to the location that they were tagged with when saving, or specified by ``map_location``. This