mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add section to serialization note re weights_only (#139433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139433
Approved by: https://github.com/malfet
ghstack dependencies: #138936, #139221
parent
a1f854f270
commit
a979318ef7
2 changed files with 86 additions and 0 deletions

@@ -176,6 +176,7 @@ can use this pattern:
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>


.. _serialized-file-format:

Serialized file format for ``torch.save``

@@ -214,6 +215,90 @@ is 64-byte aligned.
such, their storages are not serialized. In these cases ``data/`` might not exist
in the checkpoint.

.. _weights-only:

``torch.load`` with ``weights_only=True``
-----------------------------------------

Starting in version 2.6, ``torch.load`` will use ``weights_only=True`` if the ``pickle_module``
argument is not passed.

As discussed in the documentation for :func:`torch.load`, ``weights_only=True`` restricts
the unpickler used in ``torch.load`` to only executing functions/building classes required for
``state_dicts`` of plain ``torch.Tensors`` as well as some other primitive types. Further,
unlike the default ``Unpickler`` provided by the ``pickle`` module, the ``weights_only`` Unpickler
is not allowed to dynamically import anything during unpickling.
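
The restriction described above can be illustrated with the standard library alone. The following is a minimal sketch, not PyTorch's actual implementation: ``AllowlistUnpickler``, ``ALLOWED_GLOBALS`` and ``restricted_loads`` are hypothetical names, and the allowlist contents are chosen only for this example. It overrides ``pickle.Unpickler.find_class`` so that only pre-imported, allowlisted globals can be resolved and nothing is imported while unpickling.

```python
import io
import pickle
from collections import OrderedDict
from fractions import Fraction

# Hypothetical allowlist for this sketch only; it is NOT PyTorch's default list.
# Values are already-imported objects, so resolving a global never triggers an import.
ALLOWED_GLOBALS = {("collections", "OrderedDict"): OrderedDict}


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that resolves only globals present in ALLOWED_GLOBALS."""

    def find_class(self, module, name):
        try:
            return ALLOWED_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name}"
            ) from None


def restricted_loads(data: bytes):
    """Load a pickle byte string through the allowlisting unpickler."""
    return AllowlistUnpickler(io.BytesIO(data)).load()


# An OrderedDict of plain Python values only references an allowlisted global.
loaded = restricted_loads(pickle.dumps(OrderedDict(weight=[1.0, 2.0])))

# A Fraction references fractions.Fraction, which is not allowlisted.
try:
    restricted_loads(pickle.dumps(Fraction(1, 2)))
    blocked_error = None
except pickle.UnpicklingError as exc:
    blocked_error = str(exc)

print(loaded)
print(blocked_error)
```

Adding ``("fractions", "Fraction"): Fraction`` to the hypothetical allowlist would let the second load succeed, which mirrors the idea of allowlisting a trusted class before loading.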

As mentioned above, saving a module's ``state_dict`` is a best practice when using ``torch.save``. If loading an old
checkpoint that contains an ``nn.Module``, we recommend ``weights_only=False``. When loading a checkpoint that contains
tensor subclasses, there will likely be functions/classes that need to be allowlisted; see below for further details.

If the ``weights_only`` Unpickler encounters a function or class within the pickle file that is not allowlisted
by default, you should see an actionable error like the following:

.. code::

    _pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
    to do so you have two options, do those steps only if you trust the source of the checkpoint.
        1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
            but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
        2. Alternatively, to load with `weights_only=True` please check the recommended
           steps in the following error message.
           WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
           default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
           `torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
           if you trust this class/function.

Please follow the steps in the error message and allowlist the functions or classes only if you trust them.

To get all GLOBALs (functions/classes) in the checkpoint that are not yet allowlisted, you can use
:func:`torch.serialization.get_unsafe_globals_in_checkpoint`, which returns a list of strings of the form
``{__module__}.{__name__}``. If you trust these functions/classes, you can import them and allowlist them per
the error message, either via :func:`torch.serialization.add_safe_globals` or the context manager
:class:`torch.serialization.safe_globals`.
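
To make the idea of listing a pickle's globals without executing it concrete, here is a rough standard-library sketch. It is an illustration only and not how PyTorch implements its checkpoint inspection: ``globals_in_pickle`` and ``allowlisted`` are hypothetical names, and the example assumes pickle protocol 3 so that every referenced global appears as a ``GLOBAL`` opcode (newer protocols use ``STACK_GLOBAL``, which this sketch does not handle).

```python
import pickle
import pickletools
from collections import OrderedDict
from fractions import Fraction


def globals_in_pickle(data: bytes) -> set:
    """Collect ``module.name`` strings from GLOBAL opcodes without unpickling.

    Assumption for this sketch: the pickle was written with protocol <= 3,
    so globals are encoded as GLOBAL opcodes rather than STACK_GLOBAL.
    """
    found = set()
    for opcode, arg, _position in pickletools.genops(data):
        if opcode.name == "GLOBAL":
            # pickletools reports the argument as "module name".
            module, name = arg.split(" ", 1)
            found.add(f"{module}.{name}")
    return found


# Hypothetical allowlist, used only to show the set difference.
allowlisted = {"collections.OrderedDict"}

payload = pickle.dumps(OrderedDict(scale=Fraction(1, 2)), protocol=3)
unsafe = sorted(globals_in_pickle(payload) - allowlisted)
print(unsafe)
```

Because the scan only reads opcodes, it can report only globals that are literally named in the byte stream.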

To access the list of user-allowlisted functions/classes, use :func:`torch.serialization.get_safe_globals`;
to clear the current list, use :func:`torch.serialization.clear_safe_globals`.

Troubleshooting ``weights_only``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Getting unsafe globals
""""""""""""""""""""""

A caveat is that :func:`torch.serialization.get_unsafe_globals_in_checkpoint` analyzes the checkpoint statically,
so some types that are built dynamically during the unpickling process will not be reported by
:func:`torch.serialization.get_unsafe_globals_in_checkpoint`. One such example is ``dtypes`` in numpy. In
``numpy < 1.25``, after allowlisting all the functions/classes reported by
:func:`torch.serialization.get_unsafe_globals_in_checkpoint`, you might see an error like

.. code::

    WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
    but got <class 'numpy.dtype[float32]'>

This can be allowlisted via ``{add_}safe_globals([type(np.dtype(np.float32))])``.

In ``numpy >= 1.25`` you would see

.. code::

    WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
    but got <class 'numpy.dtypes.Float32DType'>

This can be allowlisted via ``{add_}safe_globals([np.dtypes.Float32DType])``.

Environment Variables
"""""""""""""""""""""

There are two environment variables that will influence the behavior of ``torch.load``. These can be helpful
if one does not have access to the ``torch.load`` callsites.

* ``TORCH_FORCE_WEIGHTS_ONLY_LOAD=1`` will override all ``torch.load`` callsites to use ``weights_only=True``.
* ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1`` will make ``torch.load`` callsites use ``weights_only=False`` **only**
  if ``weights_only`` was not passed as an argument.
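
The precedence described in the two bullets above can be written down as a small decision function. This is a sketch under stated assumptions, not PyTorch code: ``effective_weights_only`` is a hypothetical helper, it treats only the literal value ``"1"`` as enabling a variable, it assumes the version 2.6 default of ``weights_only=True`` when nothing is passed, and it lets ``TORCH_FORCE_WEIGHTS_ONLY_LOAD`` win if both variables are set (the text above does not say how that case is handled).

```python
import os
from typing import Mapping, Optional


def effective_weights_only(
    passed: Optional[bool],
    env: Mapping[str, str] = os.environ,
) -> bool:
    """Return the ``weights_only`` value a ``torch.load`` call would end up using.

    ``passed`` is the caller's explicit argument, or None if it was omitted.
    Hypothetical helper for illustration; see the assumptions in the lead-in.
    """
    # Overrides every callsite, even one that explicitly passed False.
    if env.get("TORCH_FORCE_WEIGHTS_ONLY_LOAD") == "1":
        return True
    # Only consulted when the callsite did not pass weights_only at all.
    if passed is None:
        return env.get("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD") != "1"
    return passed


# An explicit weights_only=True is not overridden by the "no weights only" variable.
print(effective_weights_only(True, {"TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD": "1"}))
# An omitted argument falls back to False when that variable is set.
print(effective_weights_only(None, {"TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD": "1"}))
```

In practice the variables would be set in the process environment before the program starts, for example when the ``torch.load`` calls live in third-party code that cannot be edited.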


.. _serializing-python-modules:

Serializing torch.nn.Modules and loading them in C++

@@ -1248,6 +1248,7 @@ def load(
weights_only: Indicates whether unpickler should be restricted to
loading only tensors, primitive types, dictionaries
and any types added via :func:`torch.serialization.add_safe_globals`.
See :ref:`weights-only` for more details.
mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
are moved to the location that they were tagged with when saving, or specified by ``map_location``. This