2024-06-08 18:24:41 +00:00
# mypy: allow-untyped-defs
Add `weights_only` option to `torch.load` (#86812)
This addresses the security issue in Python's default unpickler, which allows arbitrary code execution while unpickling.
Restrict the classes allowed to be unpickled to `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict`, as well as `torch.Size`, `torch.nn.Parameter`, and the `torch.Tensor` and `torch.Storage` variants.
By default `weights_only` is set to `False`, but weights-only loading can be forced globally via the `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable.
To some extent, addresses https://github.com/pytorch/pytorch/issues/52596
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86812
Approved by: https://github.com/ezyang
2022-10-21 01:09:50 +00:00
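The core idea, resolving globals only from an allowlist and never importing by name, can be sketched with the standard library alone. This is not how the PyTorch unpickler is written (it re-implements the opcode loop and dispatches only on the opcodes it imports). The sketch instead uses the `find_class` override technique from the Python `pickle` documentation. It also assumes CPython, because it relies on the private `_compat_pickle` tables. The allowlist contents and the names `AllowlistUnpickler` and `restricted_loads` are hypothetical.

```python
import _compat_pickle  # CPython's private Python 2 -> 3 name tables (assumption: CPython)
import collections
import io
import pickle

# Hypothetical allowlist: (module, name) -> object. Globals are resolved by
# dictionary lookup only; nothing is imported by name.
_ALLOWED_GLOBALS = {
    ("collections", "OrderedDict"): collections.OrderedDict,
    ("builtins", "set"): set,
}


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # An overriding find_class receives the raw names from the stream, so
        # protocol <= 2 pickles still say "__builtin__". Translate first
        # (unconditionally here, for simplicity).
        if (module, name) in _compat_pickle.NAME_MAPPING:
            module, name = _compat_pickle.NAME_MAPPING[(module, name)]
        elif module in _compat_pickle.IMPORT_MAPPING:
            module = _compat_pickle.IMPORT_MAPPING[module]
        try:
            return _ALLOWED_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"global '{module}.{name}' is not allowlisted"
            ) from None


def restricted_loads(data: bytes):
    """Unpickle `data`, refusing any global that is not in _ALLOWED_GLOBALS."""
    return AllowlistUnpickler(io.BytesIO(data)).load()
```

Plain containers of primitives never reach `find_class`, so they load unchanged. A pickle that names something like `builtins.eval` fails at the lookup, before anything could call it.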
# Unpickler restricted to loading only state dicts
# Restrict constructing types to a list defined in _get_allowed_globals()
# Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only
# Restrict APPEND/APPENDS to `list`
# In the `GLOBAL` operation, do not look classes up by name; instead rely on the dictionary
# defined by the `_get_allowed_globals()` method, which contains:
# - torch types (Storage, dtypes, Tensor, `torch.Size`),
# - `torch._utils._rebuild` functions.
# - `torch.nn.Parameter`
2024-04-16 17:00:50 +00:00
# - `collections.Counter`
# - `collections.OrderedDict`
Allow tensor subclasses and add `torch.serialization.add_safe_globals` that allows users to allowlist classes for `weights_only` load (#124331)
#### Conditions for allowlisting tensor subclasses
We allow tensor subclass types that
(1) do not override `__setstate__`, `__getattr__`, `__setattr__`, `__get__`, `__set__` or `__getattribute__` of `torch.Tensor` (`torch.Tensor` does not define `__getattr__`, `__get__` or `__set__`, so we check that these are `None`),
(2) use the generic `tp_alloc`, and
(3) are in a module that *has been imported by the user*
to be pushed onto the stack as strings by `GLOBAL` instructions, while storing the type in a dict.
The strings are converted to the classes as appropriate when executing `REDUCE` with `_rebuild_from_type_v2`.
\*Note that we use `inspect.getattr_static(sys.modules[module], name)` to get the class/function, as this method claims not to execute any code.
The rationale for the 3 conditions above is as follows:
The rebuild func provided by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, which is defined as such (note the call to `getattr`, `Tensor.__setstate__` and the call to `as_subclass` as well as the call to `_set_obj_state` which calls `setattr`)
https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/_tensor.py#L57-L71
`as_subclass` is implemented with a call to `THPVariable_NewWithVar`
that will eventually call `tp_alloc` here
https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/csrc/autograd/python_variable.cpp#L2053
The `func` arg to `_rebuild_from_type_v2` for wrapper subclasses is `Tensor.rebuild_wrapper_subclass`, which will similarly call into `THPVariable_NewWithVar` and hit the above `tp_alloc`
**Note that we do not call `tp_init` or `tp_new` (i.e. `cls.__init__` or `cls.__new__`) when unpickling**
### How do we check something is a tensor subclass/constraints around imports
In order to check whether `bla` is a tensor subclass in the bytecode `GLOBAL module.name`, we need to do an `issubclass` check, which entails converting the global string to the appropriate type. We *do not* arbitrarily import modules, but we perform this check as long as the given subclass (given by `module.name`) has already been imported by the user (i.e. `module in sys.modules` and `issubclass(getattr(sys.modules[module], name), torch.Tensor)`).
This PR also allowlisted `torch._utils._rebuild_wrapper_subclass` and `torch.device` (used by `_rebuild_wrapper_subclass`).
### API for allowlisting
This PR also added `torch.serialization.{add/get/clear}_safe_globals`, which let users allowlist globals they have deemed safe and manipulate this list (for example, they could allowlist a tensor subclass with a custom `__setstate__` after checking that it is safe).
Next steps:
- Add testing and allowlist required classes for all in-core tensor subclasses (e.g. `DTensor`, `FakeTensor` etc.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124331
Approved by: https://github.com/albanD
2024-05-17 14:54:46 +00:00
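Conditions (1) and (3), together with the `inspect.getattr_static` lookup, can be sketched with the standard library only. This is an illustration under assumptions. `base` stands in for `torch.Tensor`. The helper name `resolve_subclass_if_imported` is hypothetical. Condition (2), the generic `tp_alloc`, is a C-level property that this Python sketch does not check.

```python
import inspect
import sys
from typing import Optional

# Attributes that condition (1) says a subclass must not override.
_MUST_NOT_OVERRIDE = (
    "__setstate__", "__getattr__", "__setattr__",
    "__get__", "__set__", "__getattribute__",
)


def resolve_subclass_if_imported(module: str, name: str, base: type) -> Optional[type]:
    """Return module.name if it is an acceptable subclass of `base`, else None."""
    mod = sys.modules.get(module)
    if mod is None:
        # Condition (3): never import; only look at modules already imported.
        return None
    # getattr_static does not trigger a module-level __getattr__ or descriptors.
    cls = inspect.getattr_static(mod, name, None)
    if not (isinstance(cls, type) and issubclass(cls, base)):
        return None
    # Condition (1): each attribute must resolve to the same object as on
    # `base` (both None when neither class defines it).
    for attr in _MUST_NOT_OVERRIDE:
        if inspect.getattr_static(cls, attr, None) is not inspect.getattr_static(base, attr, None):
            return None
    return cls
```

A subclass that defines its own `__setstate__` is rejected here. That matches the text's suggestion that such classes must be allowlisted explicitly by a user who has checked them.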
# Additionally, users can allowlist classes they have deemed safe using
# `_add_safe_globals()` (`torch.serialization.add_safe_globals`)
# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`)
# `_get_safe_globals()` (`torch.serialization.get_safe_globals`)
# Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
# Expected to be useful for loading PyTorch model weights
# For example:
#   data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
#   buf = io.BytesIO(data)
#   weights = torch.load(buf, weights_only=True)
import functools as _functools
Fix allowlisting of builtins for weights_only unpickler (#129244)
Since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), some functions/classes that were renamed between Python 2 and Python 3 will be pickled with their Python 2 names. This PR ensures that when `GLOBAL <python2_mod>.<python2_name>` is encountered, it is mapped to `<python3_mod>.<python3_name>`, [following the strategy used by pickle](https://github.com/python/cpython/blob/main/Lib/pickle.py#L1590C13-L1593C63).
This fix ensures that `add_safe_globals` works properly for such functions/classes (i.e. users allowlist the Python 3 function, and the weights_only unpickler does the appropriate translation when checking whether a class was allowlisted).
An example is as follows:
`__builtin__` was renamed to `builtins`; see the [release notes for Python 3.0](https://docs.python.org/3/whatsnew/3.0.html):
> Renamed module `__builtin__` to [`builtins`](https://docs.python.org/3/library/builtins.html#module-builtins) (removing the underscores, adding an ‘s’). The `__builtins__` variable found in most global namespaces is unchanged. To modify a builtin, you should use [builtins](https://docs.python.org/3/library/builtins.html#module-builtins), not `__builtins__`!
However, since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), builtins will be pickled with their module string as `__builtin__`.
```python
>>> import pickle
>>> import pickletools
>>> print.__module__
'builtins'
>>> with open('print.pkl', 'wb') as f:
...     pickle.dump(print, f, protocol=2)  # 2 because this is the default protocol used by PyTorch
...
>>> with open('print.pkl', 'rb') as f:
...     pickletools.dis(f)
...
    0: \x80 PROTO      2
    2: c    GLOBAL     '__builtin__ print'
   21: q    BINPUT     0
   23: .    STOP
highest protocol among opcodes = 2
```
Note that pickle saves the module string as `__builtin__`, not `builtins`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129244
Approved by: https://github.com/albanD
2024-06-25 01:08:26 +00:00
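The translation step can be sketched with the standard library. This file imports its own `IMPORT_MAPPING`/`NAME_MAPPING` from `torch._utils`. The sketch instead uses CPython's private `_compat_pickle` tables, which is an assumption about the runtime, and `to_py3_global` is a hypothetical name. As in `pickle.Unpickler.find_class`, `(module, name)` pair renames are checked before whole-module renames.

```python
import _compat_pickle  # CPython's private Python 2 -> 3 tables (stand-in for torch._utils')
import pickle


def to_py3_global(module: str, name: str) -> tuple[str, str]:
    """Translate a GLOBAL read from a protocol-2 pickle to its Python 3 name."""
    # Renamed (module, name) pairs, e.g. __builtin__.xrange -> builtins.range
    if (module, name) in _compat_pickle.NAME_MAPPING:
        return _compat_pickle.NAME_MAPPING[(module, name)]
    # Whole-module renames, e.g. __builtin__ -> builtins
    if module in _compat_pickle.IMPORT_MAPPING:
        return _compat_pickle.IMPORT_MAPPING[module], name
    return module, name


# The stream spells the module as __builtin__ ...
raw = pickle.dumps(print, protocol=2)
assert b"c__builtin__\nprint\n" in raw
# ... so the allowlist key has to be computed after translation.
module, name = to_py3_global("__builtin__", "print")
assert f"{module}.{name}" == f"{print.__module__}.{print.__name__}"
```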
import warnings
2024-08-13 02:20:25 +00:00
from _codecs import encode
from collections import Counter, OrderedDict
from pickle import (
    APPEND,
    APPENDS,
2023-02-15 23:13:21 +00:00
    BINFLOAT,
    BINGET,
    BININT,
    BININT1,
    BININT2,
    BINPERSID,
    BINPUT,
    BINUNICODE,
    BUILD,
    bytes_types,
    decode_long,
    EMPTY_DICT,
    EMPTY_LIST,
    EMPTY_SET,
    EMPTY_TUPLE,
    GLOBAL,
    LONG1,
    LONG_BINGET,
    LONG_BINPUT,
    MARK,
    NEWFALSE,
    NEWOBJ,
    NEWTRUE,
    NONE,
    PROTO,
    REDUCE,
    SETITEM,
    SETITEMS,
    SHORT_BINSTRING,
    STOP,
    TUPLE,
    TUPLE1,
    TUPLE2,
    TUPLE3,
    UnpicklingError,
)
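The opcodes imported above are the ones the restricted unpickler dispatches on. `pickletools.genops` can show which opcodes and `GLOBAL` references a given protocol-2 pickle contains, without executing it. This is a static inspection aid, not part of the unpickler, and the names `SUPPORTED_OPCODES` and `inspect_pickle` are hypothetical. Only `GLOBAL` is parsed here. `STACK_GLOBAL` (protocol 4+) takes its names from the stack rather than from an inline argument, so it simply shows up as unsupported.

```python
import pickle
import pickletools

# Names of the opcodes imported from `pickle` above (hypothetical constant).
SUPPORTED_OPCODES = {
    "APPEND", "APPENDS", "BINFLOAT", "BINGET", "BININT", "BININT1", "BININT2",
    "BINPERSID", "BINPUT", "BINUNICODE", "BUILD", "EMPTY_DICT", "EMPTY_LIST",
    "EMPTY_SET", "EMPTY_TUPLE", "GLOBAL", "LONG1", "LONG_BINGET", "LONG_BINPUT",
    "MARK", "NEWFALSE", "NEWOBJ", "NEWTRUE", "NONE", "PROTO", "REDUCE",
    "SETITEM", "SETITEMS", "SHORT_BINSTRING", "STOP", "TUPLE", "TUPLE1",
    "TUPLE2", "TUPLE3",
}


def inspect_pickle(data: bytes):
    """Return (unsupported opcode names, GLOBAL references) without unpickling."""
    unsupported, globals_seen = set(), []
    for opcode, arg, _pos in pickletools.genops(data):
        if opcode.name not in SUPPORTED_OPCODES:
            unsupported.add(opcode.name)
        if opcode.name == "GLOBAL":
            # genops renders GLOBAL's argument as "module name".
            module, name = arg.split(" ", 1)
            globals_seen.append((module, name))
    return unsupported, globals_seen
```

For example, `inspect_pickle(pickle.dumps(print, protocol=2))` reports the single global `("__builtin__", "print")`. The same data pickled with protocol 4 uses opcodes such as `SHORT_BINUNICODE` that are not in the set.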
from struct import unpack
2024-06-04 20:55:14 +00:00
from sys import maxsize
2025-01-21 21:42:12 +00:00
from typing import Any, Callable, Union
import torch
2024-06-26 02:56:05 +00:00
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
2024-07-22 16:06:57 +00:00
# modules in this list are never allowed, even if the user attempts to allowlist
# functions/classes from them
_blocklisted_modules = [
    "sys",
    "os",
    "posix",
    "nt",
]
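The comment above says the blocklist takes precedence over any user allowlisting. A minimal sketch of that ordering follows. It assumes the check is an exact match on the module string of a `GLOBAL`, so this sketch on its own would not reject a submodule such as `os.path`. `resolve_global` is a hypothetical name, not the function in this file. `posix` and `nt` matter because functions such as `os.system` report their `__module__` as the platform module, and that is the string a pickle records.

```python
import os
import pickle
from collections import OrderedDict

_BLOCKLISTED_MODULES = ["sys", "os", "posix", "nt"]  # mirrors the list above


def resolve_global(module: str, name: str, allowed: dict):
    """Hypothetical resolver: the blocklist wins over any allowlist entry."""
    full_path = f"{module}.{name}"
    if module in _BLOCKLISTED_MODULES:
        raise pickle.UnpicklingError(f"{full_path} is in a blocklisted module ({module})")
    if full_path in allowed:
        return allowed[full_path]
    raise pickle.UnpicklingError(f"{full_path} is not allowlisted")


# Even an explicit allowlist entry for os.getcwd is refused.
allowed = {"os.getcwd": os.getcwd, "collections.OrderedDict": OrderedDict}
assert resolve_global("collections", "OrderedDict", allowed) is OrderedDict
```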
_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set()
def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]):
2024-11-02 01:58:05 +00:00
    global _marked_safe_globals_set
    _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))
def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
    global _marked_safe_globals_set
    return list(_marked_safe_globals_set)
def _clear_safe_globals():
    global _marked_safe_globals_set
    _marked_safe_globals_set = set()
2024-12-05 22:48:40 +00:00
def _remove_safe_globals(
    globals_to_remove: list[Union[Callable, tuple[Callable, str]]],
):
    global _marked_safe_globals_set
    _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)
2024-07-12 16:16:10 +00:00
class _safe_globals:
    def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]):
        self.safe_globals = safe_globals

    def __enter__(self):
        _add_safe_globals(self.safe_globals)

    def __exit__(self, type, value, tb):
        _remove_safe_globals(self.safe_globals)
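The registry above can be replicated without torch. It consists of a module-level set, `_safe_globals` as a scoped add/remove, and the uncached `_get_user_allowed_globals` defined further down. The replica shows why the user allowlist is kept out of the `lru_cache`. All names are hypothetical stand-ins, not PyTorch API. `builtin_allowed` plays `_get_allowed_globals`. `user_allowed` plays `_get_user_allowed_globals`, without its tuple validation. `stale_combined` is a deliberately wrong variant that caches both together. The replica also shows one consequence of `_remove_safe_globals` using a set difference: leaving the context drops an entry even if it was added before the context was entered.

```python
import functools
from collections import OrderedDict

# Standalone replica of the registry above (hypothetical names, no torch).
_user_globals: set = set()


def add_globals(items):
    global _user_globals
    _user_globals = _user_globals.union(set(items))


def remove_globals(items):
    global _user_globals
    _user_globals = _user_globals - set(items)


class scoped_globals:
    """Add entries on enter and remove them on exit, like _safe_globals."""

    def __init__(self, items):
        self.items = items

    def __enter__(self):
        add_globals(self.items)

    def __exit__(self, exc_type, exc, tb):
        remove_globals(self.items)


@functools.lru_cache(maxsize=1)
def builtin_allowed():
    # Cached: computed once, like _get_allowed_globals.
    return {"collections.OrderedDict": OrderedDict}


def user_allowed():
    # Not cached: rebuilt on every call, like _get_user_allowed_globals.
    rc = {}
    for f in _user_globals:
        if isinstance(f, tuple):
            f, name = f  # (object, "full.path.used.in.the.pickle")
            rc[name] = f
        else:
            rc[f"{f.__module__}.{f.__name__}"] = f
    return rc


@functools.lru_cache(maxsize=1)
def stale_combined():
    # Anti-pattern: snapshots _user_globals on the first call and never refreshes.
    return {**builtin_allowed(), **user_allowed()}
```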
# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals.
# For example, if a user had a script like
#   torch.load(file_a)
#   torch.serialization._add_safe_globals([torch.foo])
#   torch.load(file_b)
# the dynamic additions to safe_globals would not be picked up by
# _get_allowed_globals, because its result is cached after the first call.
def _get_user_allowed_globals():
    rc: dict[str, Any] = {}
    for f in _marked_safe_globals_set:
        if isinstance(f, tuple):
            if len(f) != 2:
                raise ValueError(
                    f"Expected tuple of length 2 (global, str of callable full path), but got tuple of length: {len(f)}"
                )
            if type(f[1]) is not str:
                raise TypeError(
                    f"Expected second item in tuple to be str of callable full path, but got: {type(f[1])}"
                )
            f, name = f
            rc[name] = f
        else:
            module, name = f.__module__, f.__name__
            rc[f"{module}.{name}"] = f
|
2024-05-17 14:54:46 +00:00
    return rc
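The keying rule in `_get_user_allowed_globals` can be exercised in isolation. The sketch below copies that logic into a standalone function. The names `build_user_allowlist`, `Point` and the path `legacy_pkg.geometry.Point` are hypothetical and used only for illustration. Each entry takes one of two forms:

- an object, keyed by `f"{obj.__module__}.{obj.__name__}"`;
- an `(obj, "full.path")` tuple, which registers the object under an explicit path.

```python
from collections import OrderedDict
from typing import Any


def build_user_allowlist(marked: set) -> dict[str, Any]:
    """Standalone copy of the keying logic (hypothetical name)."""
    rc: dict[str, Any] = {}
    for entry in marked:
        if isinstance(entry, tuple):
            if len(entry) != 2:
                raise ValueError(f"Expected tuple of length 2, got {len(entry)}")
            if type(entry[1]) is not str:
                raise TypeError(f"Expected str path, got {type(entry[1])}")
            obj, name = entry
            rc[name] = obj  # explicit path replaces obj.__module__/__name__
        else:
            rc[f"{entry.__module__}.{entry.__name__}"] = entry
    return rc


class Point:  # hypothetical user class
    pass


allow = build_user_allowlist({OrderedDict, (Point, "legacy_pkg.geometry.Point")})
print(sorted(allow))  # ['collections.OrderedDict', 'legacy_pkg.geometry.Point']
```

The tuple form presumably helps when the path recorded in a checkpoint differs from where the object lives today. That motivation is an inference from the error messages, not something the source states.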


def _tensor_rebuild_functions():
    return {
        torch._utils._rebuild_parameter,
        torch._utils._rebuild_parameter_with_state,
        torch._utils._rebuild_qtensor,
        torch._utils._rebuild_tensor,
        torch._utils._rebuild_tensor_v2,
        torch._utils._rebuild_tensor_v3,
        torch._utils._rebuild_sparse_tensor,
        torch._utils._rebuild_meta_tensor_no_storage,
        torch._utils._rebuild_nested_tensor,
        torch._utils._rebuild_wrapper_subclass,
2024-08-15 19:48:35 +00:00
        # Allowlist this, but do not allowlist the numpy functions by default.
        # The reasoning is that we don't control the numpy functions, but
        # this utility is provided by PyTorch.
        torch._utils._rebuild_device_tensor_from_numpy,
2024-10-09 16:02:05 +00:00
        # In 2.6, we should no longer have a dependency on numpy and the above
        # _rebuild_device_tensor_from_numpy function.
        torch._utils._rebuild_device_tensor_from_cpu_tensor,
2024-05-17 14:54:46 +00:00
    }
2022-10-21 01:09:50 +00:00

# Unpickling machinery
@_functools.lru_cache(maxsize=1)
def _get_allowed_globals():
2025-01-21 21:42:12 +00:00
    rc: dict[str, Any] = {
2022-10-21 01:09:50 +00:00
        "collections.OrderedDict": OrderedDict,
2024-04-16 17:00:50 +00:00
        "collections.Counter": Counter,
2022-10-21 01:09:50 +00:00
        "torch.nn.parameter.Parameter": torch.nn.Parameter,
        "torch.serialization._get_layout": torch.serialization._get_layout,
        "torch.Size": torch.Size,
        "torch.Tensor": torch.Tensor,
2024-05-17 14:54:46 +00:00
        "torch.device": torch.device,
2024-08-13 02:20:25 +00:00
        "_codecs.encode": encode,  # for bytes
        "builtins.bytearray": bytearray,  # for bytearray
2024-10-25 05:23:08 +00:00
        "builtins.set": set,  # for set
2024-11-18 22:03:09 +00:00
        "builtins.complex": complex,  # for complex
2022-10-21 01:09:50 +00:00
    }
2024-11-18 22:03:08 +00:00
2022-10-21 01:09:50 +00:00
    # dtype
2024-04-22 23:44:30 +00:00
    for t in torch.storage._dtype_to_storage_type_map().keys():
        rc[str(t)] = t
    for t in torch.storage._new_dtypes():
2022-10-21 01:09:50 +00:00
        rc[str(t)] = t
    # Tensor classes
    for tt in torch._tensor_classes:
        rc[f"{tt.__module__}.{tt.__name__}"] = tt
    # Storage classes
    for ts in torch._storage_classes:
[BE] Do not warn when safely loading legacy dicts (#113614)
Use the same strategy as for unsafe pickler, i.e. use dummy `torch.serialization.StorageType` to represent legacy typed storage classes during deserialization. Add `_dtype` property to be able to use it for both new and legacy format deserialization.
Parametrize `test_serialization_new_format_old_format_compat`
Add regression test to validate that loading legacy modes can be done
without any warnings
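The regression test described above asserts that loading emits no warning. A generic, stdlib-only version of that check might look like the sketch below. The names `assert_no_warnings`, `load_legacy` and `load_legacy_noisy` are hypothetical; the real test lives in `test/test_serialization.py`.

```python
import warnings


def assert_no_warnings(fn, *args, **kwargs):
    """Call fn and fail if it emitted any warning (hypothetical helper)."""
    with warnings.catch_warnings(record=True) as w:
        # "always" disables the once-per-location deduplication, so a warning
        # already triggered earlier in the process is still recorded here.
        warnings.simplefilter("always")
        result = fn(*args, **kwargs)
    assert len(w) == 0, f"Expected no warnings but got {[str(x.message) for x in w]}"
    return result


def load_legacy(payload):  # hypothetical stand-in for the code under test
    return dict(payload)


def load_legacy_noisy(payload):  # warns, like the pre-fix code path did
    warnings.warn("TypedStorage is deprecated", UserWarning)
    return dict(payload)


print(assert_no_warnings(load_legacy, {"a": 1}))  # {'a': 1}
```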
Before the change:
```
% python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_safe_cpu (__main__.TestBothSerializationCPU) ... /Users/nshulga/git/pytorch/pytorch/torch/_utils.py:836: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
ok
----------------------------------------------------------------------
Ran 2 tests in 0.116s
OK
```
Without the change, but with the test updated to catch warnings:
```
% python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_weights_only_False_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU) ... FAIL
======================================================================
FAIL: test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2536, in wrapper
    method(*args, **kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/test/test_serialization.py", line 807, in test_serialization_new_format_old_format_compat
    self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
AssertionError: False is not true : Expected no warnings but got ["{message : UserWarning('TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()'), category : 'UserWarning', filename : '/Users/nshulga/git/pytorch/pytorch/torch/_utils.py', lineno : 836, line : None}"]
To execute this test, run the following from the base repo dir:
    python test/test_serialization.py -k test_serialization_new_format_old_format_compat_weights_only_True_cpu
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 2 tests in 0.109s
FAILED (failures=1)
```
Fixes problem reported in https://github.com/pytorch/pytorch/issues/52181#issuecomment-1715738910
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113614
Approved by: https://github.com/kit1980, https://github.com/albanD
2023-11-14 22:09:10 +00:00
        if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
            # Wrap legacy storage types in a dummy class
            rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
                ts.__name__
            )
        else:
            rc[f"{ts.__module__}.{ts.__name__}"] = ts
2024-04-22 23:44:30 +00:00
    # Quantization specific
    for qt in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
        torch.per_channel_affine,
        torch.per_channel_symmetric,
        torch.per_channel_affine_float_qparams,
    ]:
        rc[str(qt)] = qt
2022-10-21 01:09:50 +00:00
    # Rebuild functions
2024-05-17 14:54:46 +00:00
    for f in _tensor_rebuild_functions():
2022-10-21 01:09:50 +00:00
        rc[f"torch._utils.{f.__name__}"] = f
[fix] allow saving python attr on Tensor and Parameter via torch.save (#81616)
Fixes: https://github.com/pytorch/pytorch/issues/72129
TODO:
* [x] Fix for Parameter
Benchmark
(Measurable diff for small tensors)
```
[-------------- Save and Load --------------]
                    |  After PR  |  Before PR
1 threads: ----------------------------------
      ()            |    111.7   |    106.9
      (4, 4)        |    114.4   |    109.2
      (128, 128)    |    135.2   |    128.3
      (1024, 1024)  |   1431.9   |   1431.3

Times are in microseconds (us).
```
<details>
<summary> Benchmark Script </summary>
```python
import torch
from torch.testing._internal.common_utils import BytesIOContext
from torch.utils import benchmark
import pickle

shapes = ((), (4, 4), (128, 128), (1024, 1024))
sizes = [1, 64, 1024, 10000]
results = []

def save_load_fn(t):
    with BytesIOContext() as f:
        torch.save(t, f)
        f.seek(0)
        torch.load(f)

for shape in shapes:
    t = torch.randn(shape)
    label = 'Save and Load'
    sub_label = f'{shape}'
    results.append(benchmark.Timer(
        stmt='save_load_fn(t)',
        globals={'t': t, 'save_load_fn': save_load_fn},
        label=label,
        sub_label=sub_label,
        description='Before PR',
    ).blocked_autorange(min_run_time=2))

compare = benchmark.Compare(results)
compare.print()

with open('before_pr.pkl', 'wb') as f:
    pickle.dump(results, f)

# with open('after_pr.pkl', 'rb') as f:
#     after_pr = pickle.load(f)
# with open('before_pr.pkl', 'rb') as f:
#     before_pr = pickle.load(f)
# compare = benchmark.Compare(after_pr + before_pr)
# compare.print()
```
</details>
NOTE : **BC-Breaking** : After this PR, all tensors (also regular tensors) will be serialised using `_rebuild_from_type_v2`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81616
Approved by: https://github.com/albanD, https://github.com/kurtamohler
2022-11-11 21:11:12 +00:00

    # Handles Tensor subclasses and Tensors with attributes.
    # NOTE: It calls into the above rebuild functions for regular Tensor types.
    rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
2022-10-21 01:09:50 +00:00
    return rc
2025-01-21 21:42:12 +00:00
def _read_global_instruction(readline: Callable) -> tuple[str, str]:
2024-11-01 16:50:25 +00:00
    module = readline()[:-1].decode("utf-8")
    name = readline()[:-1].decode("utf-8")
    # Patch needed because torch.save's default protocol is 2, which records
    # Python 2 names, while users will be running this code on Python 3
    if (module, name) in NAME_MAPPING:
        module, name = NAME_MAPPING[(module, name)]
    elif module in IMPORT_MAPPING:
        module = IMPORT_MAPPING[module]
    return module, name
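Because `torch.save` defaults to pickle protocol 2, globals such as `print` are written under their Python 2 module name (`__builtin__`). The function above translates them with `NAME_MAPPING` and `IMPORT_MAPPING`. The sketch below assumes these are the tables from CPython's `_compat_pickle` module, the same ones `pickle` itself uses for protocols below 3. It replays the lookup on a raw `GLOBAL` payload; `read_global_instruction` is a hypothetical copy of the function.

```python
import io
import pickle
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING  # CPython-internal tables


def read_global_instruction(readline) -> tuple[str, str]:
    # GLOBAL stores "module\nname\n"; drop the trailing newline of each line.
    module = readline()[:-1].decode("utf-8")
    name = readline()[:-1].decode("utf-8")
    if (module, name) in NAME_MAPPING:  # renamed objects, e.g. xrange -> range
        module, name = NAME_MAPPING[(module, name)]
    elif module in IMPORT_MAPPING:      # renamed modules, e.g. __builtin__
        module = IMPORT_MAPPING[module]
    return module, name


data = pickle.dumps(print, protocol=2)
# Layout: PROTO (\x80 \x02), then the GLOBAL opcode byte b"c",
# then the newline-terminated module and name.
assert data[:3] == b"\x80\x02c"
print(read_global_instruction(io.BytesIO(data[3:]).readline))  # ('builtins', 'print')
print(read_global_instruction(io.BytesIO(b"__builtin__\nxrange\n").readline))  # ('builtins', 'range')
```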
2025-01-21 21:42:12 +00:00
def get_globals_in_pkl(file) -> set[str]:
    globals_in_checkpoint = set()
    read = file.read
    readline = file.readline
    op_to_bytes_to_read = {
        NEWOBJ[0]: 0,
        REDUCE[0]: 0,
        BUILD[0]: 0,
        APPEND[0]: 0,
        APPENDS[0]: 0,
        SETITEM[0]: 0,
        SETITEMS[0]: 0,
        MARK[0]: 0,
        TUPLE[0]: 0,
        TUPLE1[0]: 0,
        TUPLE2[0]: 0,
        TUPLE3[0]: 0,
        NONE[0]: 0,
        NEWFALSE[0]: 0,
        NEWTRUE[0]: 0,
        EMPTY_TUPLE[0]: 0,
        EMPTY_LIST[0]: 0,
        EMPTY_DICT[0]: 0,
        EMPTY_SET[0]: 0,
        BINPERSID[0]: 0,
        BININT[0]: 4,
        BININT1[0]: 1,
        BININT2[0]: 2,
        BINFLOAT[0]: 8,
        BINGET[0]: 1,
        LONG_BINGET[0]: 4,
        BINPUT[0]: 1,
        LONG_BINPUT[0]: 4,
    }
    while True:
        key = read(1)
        if not key:
            raise EOFError
        assert isinstance(key, bytes_types)
        if key[0] == GLOBAL[0]:
            module, name = _read_global_instruction(readline)
            globals_in_checkpoint.add(f"{module}.{name}")
        elif key[0] in op_to_bytes_to_read:
            bytes_to_read = op_to_bytes_to_read[key[0]]
            if bytes_to_read:
                read(bytes_to_read)
        # ops where bytes to read depends on the data
        elif key[0] == BINUNICODE[0]:
            strlen = unpack("<I", read(4))[0]
            if strlen > maxsize:
                raise UnpicklingError("String is too long")
            read(strlen)
        elif key[0] in {SHORT_BINSTRING[0], LONG1[0]}:
            strlen = read(1)[0]
            read(strlen)
        # first and last op
        elif key[0] == PROTO[0]:
2024-12-12 12:11:20 +00:00
            read(1)[0]  # skip the protocol version byte
2024-11-01 16:50:25 +00:00
        elif key[0] == STOP[0]:
            return globals_in_checkpoint
        else:
            raise UnpicklingError(f"Unsupported operand {key[0]}")
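`get_globals_in_pkl` walks the opcode stream without executing anything. It collects every `GLOBAL` it meets and skips the fixed-size operands listed in `op_to_bytes_to_read`. The stdlib's `pickletools.genops` performs the same kind of non-executing walk, so it can stand in for a quick sketch, with two differences:

- It decodes every opcode rather than rejecting unsupported ones.
- This sketch only handles `GLOBAL`, which protocols 0 to 3 emit. Protocol 4+ pickles use `STACK_GLOBAL`, whose operands come from the stack and are not handled here.

`globals_in_pickle` is a hypothetical name.

```python
import pickle
import pickletools


def globals_in_pickle(data: bytes) -> set[str]:
    found: set[str] = set()
    # genops decodes opcodes and their arguments without running any of them.
    for opcode, arg, _pos in pickletools.genops(data):
        if opcode.name == "GLOBAL":
            module, name = arg.split(" ", 1)  # genops renders the pair as "module name"
            found.add(f"{module}.{name}")
    return found


print(globals_in_pickle(pickle.dumps(print, protocol=2)))          # {'__builtin__.print'}
print(globals_in_pickle(pickle.dumps({"a": [1, 2]}, protocol=2)))  # set()
```

Note that the names come back untranslated (`__builtin__`). The function above instead routes them through `_read_global_instruction`, so they match Python 3 allowlist keys.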
2022-10-21 01:09:50 +00:00
class Unpickler:
    def __init__(self, file, *, encoding: str = "bytes"):
        self.encoding = encoding
        self.readline = file.readline
        self.read = file.read
2025-01-21 21:42:12 +00:00
        self.memo: dict[int, Any] = {}
Fix allowlisting of builtins for weights_only unpickler (#129244)
Since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), some functions/classes that were renamed between Python 2 and Python 3 will be pickled with their Python 2 names. This PR ensures that when a `GLOBAL <python2_mod>.<python2_name>` opcode is encountered, it is properly mapped to `<python3_mod>.<python3_name>`, [following the strategy used by pickle](https://github.com/python/cpython/blob/main/Lib/pickle.py#L1590C13-L1593C63).
This fix ensures that `add_safe_globals` works properly for such functions/classes (i.e. users will allowlist the python3 func and the weights_only unpickler will do the appropriate translation when checking whether a class was allowlisted).
An example is as follows:
`__builtin__` was renamed to `builtins`; see the [release notes for Python 3.0](https://docs.python.org/3/whatsnew/3.0.html):
> Renamed module `__builtin__` to [`builtins`](https://docs.python.org/3/library/builtins.html#module-builtins) (removing the underscores, adding an ‘s’). The __builtins__ variable found in most global namespaces is unchanged. To modify a builtin, you should use [builtins](https://docs.python.org/3/library/builtins.html#module-builtins), not `__builtins__`!
However, since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), builtins will be pickled with their module string as `__builtin__`.
```python
>>> import pickle
>>> import pickletools
>>> print.__module__
'builtins'
>>> with open('print.pkl', 'wb') as f:
...     pickle.dump(print, f, protocol=2)  # 2 because this is the default protocol used by pytorch
>>> with open('print.pkl', 'rb') as f:
...     pickletools.dis(f)
0: \x80 PROTO 2
2: c GLOBAL '__builtin__ print' # pickle saves the module string as __builtin__ !!! :(
21: q BINPUT 0
23: . STOP
```
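For reference, the translation can be sketched with the tables in CPython's private `_compat_pickle` module, which `pickle.Unpickler.find_class` consults when `fix_imports` is enabled and the protocol is below 3. This is an illustrative sketch, not the code added by this PR; the helper name `to_py3_global` is hypothetical.

```python
import _compat_pickle  # private CPython module holding the py2 <-> py3 name tables
import pickle


def to_py3_global(module: str, name: str) -> tuple[str, str]:
    """Map a Python 2 (module, name) pair read from a GLOBAL opcode to its Python 3 name.

    Mirrors the fix_imports branch of pickle.Unpickler.find_class: try the
    (module, name) table first, then fall back to renaming only the module.
    """
    if (module, name) in _compat_pickle.NAME_MAPPING:
        return _compat_pickle.NAME_MAPPING[(module, name)]
    if module in _compat_pickle.IMPORT_MAPPING:
        return _compat_pickle.IMPORT_MAPPING[module], name
    return module, name


# Protocol 2 writes the Python 2 module name for builtins.
payload = pickle.dumps(print, protocol=2)
print(b"c__builtin__\nprint\n" in payload)  # True
print(to_py3_global("__builtin__", "print"))  # ('builtins', 'print')
print(to_py3_global("__builtin__", "xrange"))  # ('builtins', 'range')
```

Translating before the allowlist lookup means users only ever need to register the Python 3 object.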
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129244
Approved by: https://github.com/albanD
2024-06-25 01:08:26 +00:00
        self.proto: int = -1
2022-10-21 01:09:50 +00:00

    def load(self):
        """Read a pickled object representation from the open file.

        Return the reconstituted object hierarchy specified in the file.
        """
        self.metastack = []
2025-01-21 21:42:12 +00:00
        self.stack: list[Any] = []
2022-10-21 01:09:50 +00:00
        self.append = self.stack.append
        read = self.read
        while True:
            key = read(1)
            if not key:
                raise EOFError
            assert isinstance(key, bytes_types)
            # Risky operators
            if key[0] == GLOBAL[0]:
2024-11-01 16:50:25 +00:00
                module, name = _read_global_instruction(self.readline)
2022-10-21 01:09:50 +00:00
                full_path = f"{module}.{name}"
2024-07-22 16:06:57 +00:00
                if module in _blocklisted_modules:
2024-08-15 19:48:35 +00:00
                    raise UnpicklingError(
2024-07-22 16:06:57 +00:00
                        f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
                    )
2022-10-21 01:09:50 +00:00
                if full_path in _get_allowed_globals():
                    self.append(_get_allowed_globals()[full_path])
Allow tensor subclasses and add `torch.serialization.add_safe_globals` that allows users to allowlist classes for `weights_only` load (#124331)
#### Conditions for allowlisting tensor subclasses
We allow tensor subclass types that
(1) Do not override `__setstate__`, `__getattr__`, `__setattr__`, `__get__`, `__set__` or `__getattribute__` of `torch.Tensor` (`torch.Tensor` does not have a definition of `__getattr__`, `__get__` or `__set__` so we check that these are `None`)
(2) Use the generic `tp_alloc`
(3) Are in a module that *has been imported by the user*
to be pushed onto the stack as strings by `GLOBAL` instructions, while storing the type in a dict
The strings will be converted to the classes as appropriate when executing `REBUILD` with `_rebuild_from_type_v2`
*Note that we use `inspect.getattr_static(sys.modules[module], name)` to get the class/function as this method claims to have no code execution.
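As an illustration of condition (1), a static comparison against the base class can flag overridden attributes without executing user code. This is a sketch only: the helper `overridden_guarded_attrs` and the classes are hypothetical, and a generic `base` stands in for `torch.Tensor` so it runs without PyTorch.

```python
import inspect

# Attributes a subclass must not override (condition (1) above).
GUARDED_ATTRS = (
    "__setstate__",
    "__getattr__",
    "__setattr__",
    "__get__",
    "__set__",
    "__getattribute__",
)


def overridden_guarded_attrs(cls: type, base: type) -> list[str]:
    """Return the guarded names whose static lookup differs between cls and base.

    inspect.getattr_static reads the MRO dictionaries directly, so it does not
    invoke descriptors or a custom __getattr__. A name missing on both classes
    resolves to None on both sides and therefore compares as unchanged.
    """
    return [
        name
        for name in GUARDED_ATTRS
        if inspect.getattr_static(cls, name, None)
        is not inspect.getattr_static(base, name, None)
    ]


class Base:
    def __setstate__(self, state):
        self.__dict__.update(state)


class PlainSubclass(Base):
    pass


class CustomSetattr(Base):
    def __setattr__(self, key, value):
        object.__setattr__(self, key, value)


print(overridden_guarded_attrs(PlainSubclass, Base))  # []
print(overridden_guarded_attrs(CustomSetattr, Base))  # ['__setattr__']
```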
The rationale for the 3 conditions above is as follows:
The rebuild func provided by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, which is defined as such (note the call to `getattr`, `Tensor.__setstate__` and the call to `as_subclass` as well as the call to `_set_obj_state` which calls `setattr`)
https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/_tensor.py#L57-L71
`as_subclass` is implemented with a call to `THPVariable_NewWithVar`
that will eventually call `tp_alloc` here
https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/csrc/autograd/python_variable.cpp#L2053
The `func` arg to `_rebuild_from_type_v2` for wrapper subclasses is `Tensor.rebuild_wrapper_subclass`, which will similarly call into `THPVariable_NewWithVar` and hit the above `tp_alloc`
**Note that we do not call `tp_init` or `tp_new` (i.e. `cls.__init__` or `cls.__new__`) when unpickling**
### How do we check something is a tensor subclass/constraints around imports
In order to check whether `bla` is a tensor subclass in the bytecode `GLOBAL module.name`, we need to do an `issubclass` check, which entails converting the global string to the appropriate type. We *do not* arbitrarily import modules, but will perform this check as long as the given subclass (given by `module.name`) has already been imported by the user (i.e. `module in sys.modules` and `issubclass(getattr(sys.modules[module], name), torch.Tensor)`).
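A minimal sketch of the "only if already imported" lookup, assuming a generic `base` class in place of `torch.Tensor`; `lookup_if_imported` and `is_imported_subclass` are hypothetical names, not PyTorch functions.

```python
import collections
import inspect
import sys


def lookup_if_imported(module: str, name: str):
    """Resolve module.name only if the module is already in sys.modules.

    Nothing is imported here, and inspect.getattr_static reads the module's
    __dict__ directly, so a module-level __getattr__ hook is never called.
    """
    mod = sys.modules.get(module)
    if mod is None:
        return None
    return inspect.getattr_static(mod, name, None)


def is_imported_subclass(module: str, name: str, base: type) -> bool:
    """True if module.name is an already-imported class deriving from base."""
    obj = lookup_if_imported(module, name)
    return isinstance(obj, type) and issubclass(obj, base)


print(is_imported_subclass("collections", "OrderedDict", dict))  # True
print(is_imported_subclass("collections", "deque", dict))  # False
print(is_imported_subclass("not_an_imported_module", "Foo", dict))  # False
```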
This PR also allowlisted `torch._utils._rebuild_wrapper_subclass` and `torch.device` (used by `_rebuild_wrapper_subclass`)
### API for allowlisting
This PR also added `torch.serialization.{add/get/clear}_safe_globals`, which enable users to allowlist globals they have deemed safe and to manipulate this list (for example, they could allowlist a tensor subclass with a custom `__setstate__` if they have checked that this is safe).
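The allowlist idea can be sketched with the standard library alone: a registry keyed by the same `"module.name"` string the `GLOBAL` branch builds as `full_path`, add/get/clear helpers, a context manager, and a `pickle.Unpickler.find_class` override that rejects everything else. All names here are hypothetical stand-ins for `torch.serialization.{add/get/clear}_safe_globals` and `safe_globals`; the sketch only gates which globals can be resolved and does not replicate the opcode-level checks of the weights-only unpickler.

```python
import contextlib
import io
import pickle
from collections import OrderedDict

# Hypothetical user allowlist keyed by "module.qualname", like full_path above.
_user_allowed: dict[str, object] = {}


def _full_path(obj) -> str:
    return f"{obj.__module__}.{obj.__qualname__}"


def add_allowed(objs) -> None:
    for obj in objs:
        _user_allowed[_full_path(obj)] = obj


def get_allowed() -> list:
    return list(_user_allowed.values())


def clear_allowed() -> None:
    _user_allowed.clear()


@contextlib.contextmanager
def allowed(objs):
    """Temporarily extend the allowlist, restoring the previous entries on exit."""
    saved = dict(_user_allowed)
    add_allowed(objs)
    try:
        yield
    finally:
        _user_allowed.clear()
        _user_allowed.update(saved)


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called for every GLOBAL/STACK_GLOBAL; anything not allowlisted is rejected.
        full_path = f"{module}.{name}"
        if full_path in _user_allowed:
            return _user_allowed[full_path]
        raise pickle.UnpicklingError(f"Unsupported global: GLOBAL {full_path}")


def restricted_loads(data: bytes):
    return AllowlistUnpickler(io.BytesIO(data)).load()


payload = pickle.dumps(OrderedDict(a=1), protocol=2)
try:
    restricted_loads(payload)
except pickle.UnpicklingError as exc:
    print(exc)  # Unsupported global: GLOBAL collections.OrderedDict
with allowed([OrderedDict]):
    print(restricted_loads(payload) == OrderedDict(a=1))  # True
```

Restoring a saved copy in the context manager (rather than deleting only the added keys) keeps entries that were already registered before the `with` block.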
Next steps:
- Add testing and allowlist required classes for all in-core tensor subclasses (e.g. `DTensor`, `FakeTensor` etc.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124331
Approved by: https://github.com/albanD
2024-05-17 14:54:46 +00:00
                elif full_path in _get_user_allowed_globals():
                    self.append(_get_user_allowed_globals()[full_path])
2024-11-18 22:03:08 +00:00
                elif full_path in (
                    [
                        "torch.nested._internal.nested_tensor.NestedTensor",
                        "torch.nested._internal.nested_tensor._rebuild_njt",
                        "torch._dynamo.decorators._DimRange",
                    ]
                ):
                    raise UnpicklingError(
                        "``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)"
                    )
2024-11-18 22:03:09 +00:00
                elif full_path in (
                    [
                        "torch.distributed.device_mesh.DeviceMesh",
                        "torch.distributed.tensor._dtensor_spec.DTensorSpec",
                        "torch.distributed.tensor._dtensor_spec.TensorMeta",
                        "torch.distributed.tensor.DTensor",
                        "torch.distributed.tensor.placement_types.Partial",
                        "torch.distributed.tensor.placement_types.Replicate",
                        "torch.distributed.tensor.placement_types.Shard",
                    ]
                ):
                    raise UnpicklingError(
                        "``torch.distributed.tensor`` must be imported to load DTensors"
                    )
2022-10-21 01:09:50 +00:00
                else:
2024-08-15 19:48:35 +00:00
                    raise UnpicklingError(
2024-06-04 20:55:14 +00:00
                        f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
2024-11-02 01:58:05 +00:00
                        f"Please use `torch.serialization.add_safe_globals([{name}])` or the "
                        f"`torch.serialization.safe_globals([{name}])` context manager to allowlist this global "
                        "if you trust this class/function."
2024-06-04 20:55:14 +00:00
                    )
2022-10-21 01:09:50 +00:00
            elif key[0] == NEWOBJ[0]:
                args = self.stack.pop()
                cls = self.stack.pop()
2024-06-25 01:08:26 +00:00
                if cls is torch.nn.Parameter:
                    self.append(torch.nn.Parameter(*args))
2024-11-07 22:50:15 +00:00
                elif (
                    cls in _get_user_allowed_globals().values()
                    or cls in _get_allowed_globals().values()
                ):
2024-06-25 01:08:26 +00:00
                    self.append(cls.__new__(cls, *args))
                else:
2024-08-15 19:48:35 +00:00
                    raise UnpicklingError(
Clarify error messages for NEWOBJ and BUILD in weights_only unpickler (#134346)
Clarify that `add_safe_globals` will allow types for these instructions
Some types do not appear as `GLOBAL` and are only caught in `BUILD`, example from hf slack is `numpy.dtypes.UInt32DType`
```python
import torch
import numpy as np
from tempfile import TemporaryDirectory
from pathlib import Path
from codecs import encode
torch.serialization.add_safe_globals([encode, np.dtype, np.core.multiarray._reconstruct, np.ndarray])
with TemporaryDirectory() as tempdir:
    p = Path(tempdir)
    r2 = np.random.get_state()
    torch.save(r2, p / "r2.pkl")
    torch.load(p / "r2.pkl", weights_only=True)
```
Yields (error comes from BUILD)
```
UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Can only build Tensor, parameter or OrderedDict objects, but got <class 'numpy.dtypes.UInt32DType'>
```
The reasoning is that `numpy.dtypes.UInt32DType` is constructed via `REDUCE` with `func=<class 'numpy.dtype'>` and `args=('u4', False, True)`, so this PR clarifies in the error message that calling `add_safe_globals` on these types will also allow them.
After this PR error message becomes
```
_pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`, but got <class 'numpy.dtypes.UInt32DType'>
```
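To check which opcodes a given object goes through, `pickletools.genops` lists them. In this sketch, `types.SimpleNamespace` (whose `__reduce__` returns `(cls, (), state)`) is rebuilt with `REDUCE` then `BUILD`, while `argparse.Namespace`, which relies on the default `object.__reduce_ex__`, goes through `NEWOBJ` then `BUILD` at protocol 2. These two standard-library classes are only convenient examples.

```python
import argparse
import pickle
import pickletools
import types


def opcode_names(obj, protocol=2):
    """Return the names of the opcodes pickle emits for obj, in order."""
    data = pickle.dumps(obj, protocol=protocol)
    return [opcode.name for opcode, _arg, _pos in pickletools.genops(data)]


# SimpleNamespace defines __reduce__ as (cls, (), state): REDUCE, then BUILD.
ns_ops = opcode_names(types.SimpleNamespace(a=1))
# argparse.Namespace uses object.__reduce_ex__: NEWOBJ, then BUILD.
ap_ops = opcode_names(argparse.Namespace(a=1))

print("REDUCE" in ns_ops, "BUILD" in ns_ops)  # True True
print("NEWOBJ" in ap_ops, "BUILD" in ap_ops)  # True True
```

In both cases the instance reaches `BUILD` after being created, which is why the type has to be allowlisted for that step as well.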
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134346
Approved by: https://github.com/albanD
2024-08-26 21:28:18 +00:00
                        "Can only create new object for nn.Parameter or classes allowlisted "
                        f"via `add_safe_globals` but got {cls}"
2024-08-15 19:48:35 +00:00
                    )
2022-10-21 01:09:50 +00:00
            elif key[0] == REDUCE[0]:
                args = self.stack.pop()
                func = self.stack[-1]
2024-05-17 14:54:46 +00:00
                if (
                    func not in _get_allowed_globals().values()
                    and func not in _get_user_allowed_globals().values()
                ):
2024-08-15 19:48:35 +00:00
                    raise UnpicklingError(
2022-10-21 01:09:50 +00:00
                        f"Trying to call reduce for unrecognized function {func}"
                    )
                self.stack[-1] = func(*args)
            elif key[0] == BUILD[0]:
                state = self.stack.pop()
                inst = self.stack[-1]
                if type(inst) is torch.Tensor:
                    # Legacy unpickling
                    inst.set_(*state)
                elif type(inst) is torch.nn.Parameter:
                    inst.__setstate__(state)
                elif type(inst) is OrderedDict:
                    inst.__dict__.update(state)
2024-11-07 22:50:15 +00:00
                elif (
                    type(inst) in _get_user_allowed_globals().values()
                    or type(inst) in _get_allowed_globals().values()
                ):
2024-06-25 01:08:26 +00:00
                    if hasattr(inst, "__setstate__"):
                        inst.__setstate__(state)
Remove hasattr(__slots__) for BUILD logic in weights_only unpickler (#139541)
This is tested in PR stacked above in
```
python test/distributed/fsdp/test_fsdp_state_dict.py TestFSDPStateDict.test_torch_save_load
```
We cannot rely on `hasattr(..., "__slots__")` to know whether a BUILD instruction has slotstate. For example, if a class subclasses `ABC`, `hasattr(cls, "__slots__")` will be `True`, but there might be no slots (and hence `state` will not be a tuple). So this PR reverts #138936 to follow the pickle library's code.
```python
>>> from abc import ABC
>>> hasattr(ABC, "__slots__")
True
```
So
```python
import torch
from abc import ABC
from dataclasses import dataclass
class Foo(ABC):
    pass
class FooWrapper(Foo):
    def __init__(self, x, y):
        self.x = x
        self.y = y
f = FooWrapper(1, 2)
torch.save(f, "temp.pt")
with torch.serialization.safe_globals([FooWrapper]):
    torch.load("temp.pt")
```
Would fail on the previous code with
```
File "/data/users/mg1998/pytorch/torch/serialization.py", line 1934, in _load
result = unpickler.load()
File "/data/users/mg1998/pytorch/torch/_weights_only_unpickler.py", line 366, in load
for k, v in slotstate.items():
```
As there is actually no slotstate
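The pitfall can be reproduced without PyTorch. `abc.ABC` declares `__slots__ = ()`, so `hasattr` reports slots for every subclass, yet the state returned by `object.__reduce_ex__(2)` (the value that ends up on the stack for `BUILD`) is a plain dict; only a class that really uses slots produces the `(dict_or_None, slotstate)` tuple. The classes below are illustrative.

```python
from abc import ABC


class Foo(ABC):
    pass


class FooWrapper(Foo):
    def __init__(self, x, y):
        self.x = x
        self.y = y


class Slotted:
    __slots__ = ("x", "y")

    def __init__(self, x, y):
        self.x = x
        self.y = y


# abc.ABC defines __slots__ = (), so hasattr() is True for every subclass.
print(hasattr(FooWrapper, "__slots__"))  # True

# Element 2 of the reduce tuple is the state that a BUILD opcode receives.
print(FooWrapper(1, 2).__reduce_ex__(2)[2])  # {'x': 1, 'y': 2}
print(Slotted(1, 2).__reduce_ex__(2)[2])  # (None, {'x': 1, 'y': 2})
```

Checking the shape of `state` itself, as pickle's `load_build` does, avoids depending on class attributes inherited from unrelated bases.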
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139541
Approved by: https://github.com/malfet
ghstack dependencies: #138936, #139221, #139433
2024-11-02 01:58:04 +00:00
                    else:
                        # mimics load_build in pickle
                        # https://github.com/python/cpython/blob/f0c6fccd08904787a39269367f09f263d496114c/Lib/pickle.py#L1854-L1867
                        slotstate = None
                        if isinstance(state, tuple) and len(state) == 2:
                            state, slotstate = state
Fix weights_only for BUILD instructions for user allowlisted objects with __slots__ (#138936)
Previously `BUILD` instruction missed handling for `__slots__`. **This only applies for things allowlisted via `add_safe_globals`/`safe_globals` that use slots.**
### Background
When does pickle serialize a `BUILD` instruction? When `state` is not `None` and `state_setter` is `None` [[link](https://github.com/python/cpython/blob/c5b99f5c2c5347d66b9da362773969c531fb6c85/Lib/pickle.py#L765)]. In this case, the docs tell us that either `__setstate__` or a `__dict__` update will be performed [[link](https://github.com/python/cpython/blob/3.13/Lib/pickletools.py#L1984)]
`__reduce__`/`__reduce_ex__` are expected to return tuples of length 2 to 6, where `state` is the 3rd element. When the user doesn't override `__reduce__` but overrides `__setstate__`/`__getstate__`, `state` will be whatever `__getstate__` returns.
Note the return type for [`__getstate__` ](https://docs.python.org/3/library/pickle.html#object.__getstate__)
- For a class that has no instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__) and no [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__), the default state is None.
- For a class that has an instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__) and no [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__), the default state is `self.__dict__`.
- For a class that has an instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__) and [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__), the default state is a tuple consisting of two dictionaries: `self.__dict__`, and a dictionary mapping slot names to slot values. Only slots that have a value are included in the latter.
- For a class that has [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__) and no instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__), the default state is a tuple whose first item is None and whose second item is a dictionary mapping slot names to slot values described in the previous bullet.
see handling in pickle code https://github.com/python/cpython/blob/c5b99f5c2c5347d66b9da362773969c531fb6c85/Lib/pickle.py#L1846-L1867
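The four cases can be handled with the same split that pickle's `load_build` uses. The sketch below is a hypothetical helper (`apply_build_state`) that mirrors that logic, exercised on states produced by `object.__reduce_ex__(2)` for a plain class, a slotted class, and a class with both `__slots__` and `__dict__`.

```python
def apply_build_state(inst, state):
    """Apply BUILD state the way pickle's load_build does when there is no __setstate__."""
    slotstate = None
    # A 2-tuple means (instance dict or None, {slot name: value}).
    if isinstance(state, tuple) and len(state) == 2:
        state, slotstate = state
    if state:
        inst.__dict__.update(state)
    if slotstate:
        for k, v in slotstate.items():
            setattr(inst, k, v)
    return inst


class Plain:
    def __init__(self):
        self.a = 1


class Slots:
    __slots__ = ("x", "y")

    def __init__(self):
        self.x, self.y = 2, 3


class Mixed:
    __slots__ = ("x", "__dict__")

    def __init__(self):
        self.x, self.a = 4, 5


def roundtrip(obj):
    """Rebuild obj from the (callable, args, state) part of object.__reduce_ex__(2)."""
    _newobj, (cls, *args), state = obj.__reduce_ex__(2)[:3]
    return apply_build_state(cls.__new__(cls, *args), state)


print(vars(roundtrip(Plain())))  # {'a': 1}
s = roundtrip(Slots())
print(s.x, s.y)  # 2 3
print(Mixed().__reduce_ex__(2)[2])  # ({'a': 5}, {'x': 4})
m = roundtrip(Mixed())
print(m.x, m.a)  # 4 5
```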
Before this PR, we didn't account for the fact that when `__setstate__` is not defined, `state` might be a tuple so this would fail
```python
import torch
from dataclasses import dataclass
# Define the dataclass
@dataclass
class MyDataClass:
    __slots__ = ["x", "y"]
    x: int
    y: str
# Create an instance of the dataclass
my_data = MyDataClass(x=2, y="3")
# Save the dataclass to a file
torch.save(my_data, "my_data.pt")
with torch.serialization.safe_globals([MyDataClass]):
    loaded_my_data = torch.load("my_data.pt", weights_only=True)
    # AttributeError: 'MyDataClass' object has no attribute '__dict__'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138936
Approved by: https://github.com/malfet
2024-10-31 19:06:08 +00:00
                        if state:
                            inst.__dict__.update(state)
2024-11-02 01:58:04 +00:00
                        if slotstate:
                            for k, v in slotstate.items():
                                setattr(inst, k, v)
2022-10-21 01:09:50 +00:00
                else:
2024-08-15 19:48:35 +00:00
                    raise UnpicklingError(
2024-08-26 21:28:18 +00:00
                        "Can only build Tensor, Parameter, OrderedDict or types allowlisted "
                        f"via `add_safe_globals`, but got {type(inst)}"
2022-10-21 01:09:50 +00:00
                    )
            # Stack manipulation
            elif key[0] == APPEND[0]:
                item = self.stack.pop()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
2024-08-15 19:48:35 +00:00
                    raise UnpicklingError(
2022-10-21 01:09:50 +00:00
                        f"Can only append to lists, but got {type(list_obj)}"
                    )
                list_obj.append(item)
            elif key[0] == APPENDS[0]:
                items = self.pop_mark()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
2024-08-15 19:48:35 +00:00
                    raise UnpicklingError(
2022-10-21 01:09:50 +00:00
|
|
|
f"Can only extend lists, but got {type(list_obj)}"
|
|
|
|
|
)
|
|
|
|
|
list_obj.extend(items)
|
|
|
|
|
elif key[0] == SETITEM[0]:
|
|
|
|
|
(v, k) = (self.stack.pop(), self.stack.pop())
|
|
|
|
|
self.stack[-1][k] = v
|
|
|
|
|
            elif key[0] == SETITEMS[0]:
                items = self.pop_mark()
                for i in range(0, len(items), 2):
                    self.stack[-1][items[i]] = items[i + 1]
            elif key[0] == MARK[0]:
                self.metastack.append(self.stack)
                self.stack = []
                self.append = self.stack.append
            elif key[0] == TUPLE[0]:
                items = self.pop_mark()
                self.append(tuple(items))
            elif key[0] == TUPLE1[0]:
                self.stack[-1] = (self.stack[-1],)
            elif key[0] == TUPLE2[0]:
                self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
            elif key[0] == TUPLE3[0]:
                self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
            # Basic types construction
            elif key[0] == NONE[0]:
                self.append(None)
            elif key[0] == NEWFALSE[0]:
                self.append(False)
            elif key[0] == NEWTRUE[0]:
                self.append(True)
            elif key[0] == EMPTY_TUPLE[0]:
                self.append(())
            elif key[0] == EMPTY_LIST[0]:
                self.append([])
            elif key[0] == EMPTY_DICT[0]:
                self.append({})
            elif key[0] == EMPTY_SET[0]:
                self.append(set())
            elif key[0] == BININT[0]:
                self.append(unpack("<i", read(4))[0])
            elif key[0] == BININT1[0]:
                self.append(self.read(1)[0])
            elif key[0] == BININT2[0]:
                self.append(unpack("<H", read(2))[0])
            elif key[0] == BINFLOAT[0]:
                self.append(unpack(">d", self.read(8))[0])
            elif key[0] == BINUNICODE[0]:
                strlen = unpack("<I", read(4))[0]
                if strlen > maxsize:
                    raise UnpicklingError("String is too long")
                strval = str(read(strlen), "utf-8", "surrogatepass")
                self.append(strval)
            elif key[0] == SHORT_BINSTRING[0]:
                strlen = read(1)[0]
                strdata = read(strlen)
                if self.encoding != "bytes":
                    strdata = strdata.decode(self.encoding, "strict")
                self.append(strdata)
            elif key[0] == BINPERSID[0]:
                pid = self.stack.pop()
                # Only allow persistent load of storage
                if type(pid) is not tuple and type(pid) is not int:
                    raise UnpicklingError(
                        f"persistent_load id must be tuple or int, but got {type(pid)}"
                    )
                if (
                    type(pid) is tuple
                    and len(pid) > 0
                    and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
                ):
                    raise UnpicklingError(
                        f"Only persistent_load of storage is allowed, but got {pid[0]}"
                    )
                self.append(self.persistent_load(pid))
            elif key[0] in [BINGET[0], LONG_BINGET[0]]:
                idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
                self.append(self.memo[idx])
            elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
                i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
                if i < 0:
                    raise ValueError("negative argument")
                self.memo[i] = self.stack[-1]
            elif key[0] == LONG1[0]:
                n = read(1)[0]
                data = read(n)
                self.append(decode_long(data))
            # First and last deserializer ops
            elif key[0] == PROTO[0]:
                self.proto = read(1)[0]
                if self.proto != 2:
                    warnings.warn(
                        f"Detected pickle protocol {self.proto} in the checkpoint, which was "
                        "not the default pickle protocol used by `torch.save` (2). The weights_only "
                        "Unpickler might not support all instructions implemented by this protocol; "
                        "please file an issue for adding support if you encounter this."
                    )
            elif key[0] == STOP[0]:
                rc = self.stack.pop()
                return rc
            else:
                raise UnpicklingError(f"Unsupported operand {key[0]}")

    # Return a list of the items pushed onto the stack since the last MARK instruction.
    def pop_mark(self):
        items = self.stack
        self.stack = self.metastack.pop()
        self.append = self.stack.append
        return items

    def persistent_load(self, pid):
        raise UnpicklingError("unsupported persistent id encountered")


def load(file, *, encoding: str = "ASCII"):
    return Unpickler(file, encoding=encoding).load()
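

# Illustrative sketch (hypothetical helper, not part of the weights_only API):
# a stripped-down restricted loader that mirrors the MARK / pop_mark() stack
# discipline of Unpickler.load() above. It assumes the input is a protocol-2
# pickle of a list of small non-negative ints, so it needs only a handful of
# opcodes. Every other opcode, including GLOBAL and REDUCE, is rejected, which
# means this loader never resolves a global or calls a callable.
import io
import pickle


def _sketch_load_int_list(data: bytes):
    f = io.BytesIO(data)
    stack: list = []
    metastack: list = []
    while True:
        key = f.read(1)
        if not key:
            raise EOFError
        if key == pickle.PROTO:
            f.read(1)  # protocol number; not checked in this sketch
        elif key == pickle.EMPTY_LIST:
            stack.append([])
        elif key == pickle.BINPUT:
            f.read(1)  # memo index; nothing in this sketch reads the memo back
        elif key == pickle.MARK:
            # Park the current stack; the marked items go onto a fresh one.
            metastack.append(stack)
            stack = []
        elif key == pickle.BININT1:
            stack.append(f.read(1)[0])
        elif key == pickle.APPEND:
            item = stack.pop()
            if type(stack[-1]) is not list:
                raise pickle.UnpicklingError("Can only append to lists")
            stack[-1].append(item)
        elif key == pickle.APPENDS:
            # Same effect as pop_mark(): take the items pushed since MARK.
            items, stack = stack, metastack.pop()
            if type(stack[-1]) is not list:
                raise pickle.UnpicklingError("Can only extend lists")
            stack[-1].extend(items)
        elif key == pickle.STOP:
            return stack.pop()
        else:
            raise pickle.UnpicklingError(f"Unsupported opcode {key[0]}")


# Example: _sketch_load_int_list(pickle.dumps([1, 2, 3], protocol=2)) rebuilds
# [1, 2, 3] via MARK / BININT1 / APPENDS, while pickle.dumps(print, protocol=2)
# starts with a GLOBAL opcode and is rejected with UnpicklingError.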