Fix weights_only for BUILD instructions for user allowlisted objects with __slots__ (#138936)

Previously `BUILD` instruction missed handling for `__slots__`. **This only applies for things allowlisted via `add_safe_globals`/`safe_globals` that use slots.**

### Background
When does pickle serialize a `BUILD` instruction? When `state` is not `None` and `state_setter` is `None` [[link](https://github.com/python/cpython/blob/c5b99f5c2c/Lib/pickle.py#L765)]. In this case, the docs tell us that either `__setstate__` or a `__dict__` update will be performed [[link](https://github.com/python/cpython/blob/3.13/Lib/pickletools.py#L1984)].

`__reduce__`/`__reduce_ex__` are expected to return tuples of length 2 to 6, where `state` is the 3rd item. When the user doesn't override `__reduce__` but overrides `__setstate__`/`__getstate__`, `state` will be whatever `__getstate__` returns.

Note what the default [`__getstate__`](https://docs.python.org/3/library/pickle.html#object.__getstate__) returns:

- For a class that has no instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__) and no [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__), the default state is None.
- For a class that has an instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__) and no [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__), the default state is `self.__dict__`.
- For a class that has an instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__) and [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__), the default state is a tuple consisting of two dictionaries: `self.__dict__`, and a dictionary mapping slot names to slot values. Only slots that have a value are included in the latter.
- For a class that has [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__) and no instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__), the default state is a tuple whose first item is None and whose second item is a dictionary mapping slot names to slot values described in the previous bullet.

See the corresponding handling in the pickle code: https://github.com/python/cpython/blob/c5b99f5c2c/Lib/pickle.py#L1846-L1867

Before this PR, we didn't account for the fact that when `__setstate__` is not defined, `state` might be a `(dict_state, slot_state)` tuple rather than a plain dict, so the following would fail:

```python
from dataclasses import dataclass

# Define the dataclass
@dataclass
class MyDataClass:
    __slots__ = ["x", "y"]
    x: int
    y: str
# Create an instance of the dataclass
my_data = MyDataClass(x=2, y=3)
# Save the dataclass to a file
torch.save(my_data, "my_data.pt")
with torch.serialization.safe_globals([MyDataClass]):
    loaded_my_data = torch.load("my_data.pt", weights_only=True)
# AttributeError: 'MyDataClass' object has no attribute '__dict__'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138936
Approved by: https://github.com/malfet
This commit is contained in:
Mikayla Gawarecki 2024-10-31 12:06:08 -07:00 committed by PyTorch MergeBot
parent c2ffd41a86
commit 2a309c0997
2 changed files with 38 additions and 0 deletions

View file

@ -16,6 +16,7 @@ import warnings
import zipfile import zipfile
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
@ -844,6 +845,17 @@ class ClassThatUsesBuildInstruction:
# Third item, state here will cause pickle to push a BUILD instruction # Third item, state here will cause pickle to push a BUILD instruction
return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'} return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}
@dataclass
class ClassThatUsesBuildInstructionAllSlots:
    """Dataclass whose fields are all stored in ``__slots__``.

    Instances have no ``__dict__``, so the default pickled state is a tuple
    whose first item is ``None`` and whose second item maps slot names to
    their values. Loading it goes through the ``BUILD`` instruction.
    """

    __slots__ = ["x", "y"]
    x: int
    y: int
@dataclass
class ClassThatUsesBuildInstructionSomeSlots(ClassThatUsesBuildInstructionAllSlots):
    """Dataclass mixing inherited slots with a regular instance ``__dict__``.

    This subclass declares no ``__slots__`` of its own, so its instances get a
    ``__dict__`` in addition to the ``x``/``y`` slots inherited from the base
    class. The default pickled state is therefore a tuple of two dictionaries:
    the instance ``__dict__`` and a mapping of slot names to slot values.
    """

    x: int
    y: int
    c: str
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
class TestBothSerialization(TestCase): class TestBothSerialization(TestCase):
@ -1142,6 +1154,25 @@ class TestSerialization(TestCase, SerializationMixin):
torch.serialization.clear_safe_globals() torch.serialization.clear_safe_globals()
ClassThatUsesBuildInstruction.__setstate__ = None ClassThatUsesBuildInstruction.__setstate__ = None
@parametrize("slots", ['some', 'all'])
def test_weights_only_safe_globals_build_with_slots(self, slots):
    """Round-trip a slotted dataclass through ``torch.load(weights_only=True)``.

    ``slots='all'`` uses a class whose fields are all slots; ``slots='some'``
    uses a subclass that has both slots and an instance ``__dict__``. In both
    cases loading must fail until the class is allowlisted, then succeed and
    reproduce the original object.
    """
    obj_cls = (
        ClassThatUsesBuildInstructionAllSlots if slots == 'all' else ClassThatUsesBuildInstructionSomeSlots
    )
    args = (2, 3) if slots == 'all' else (2, 3, 'foo')
    obj = obj_cls(*args)
    # Build the expected name from the class itself rather than hard-coding
    # "__main__", so the check also holds when this file is imported as a
    # module (e.g. by a test runner) instead of being run as a script.
    expected_global = f"{obj_cls.__module__}.{obj_cls.__name__}"
    with BytesIOContext() as f:
        torch.save(obj, f)
        f.seek(0)
        # Not allowlisted yet: the weights_only unpickler must reject the class.
        with self.assertRaisesRegex(pickle.UnpicklingError,
                                    f"GLOBAL {expected_global} was not an allowed global by default"):
            torch.load(f, weights_only=True)
        f.seek(0)
        # Once allowlisted, BUILD must restore both slot and __dict__ state.
        with torch.serialization.safe_globals([obj_cls]):
            loaded_obj = torch.load(f, weights_only=True)
            self.assertEqual(loaded_obj, obj)
def test_weights_only_safe_globals_blocklist(self): def test_weights_only_safe_globals_blocklist(self):
module = 'nt' if IS_WINDOWS else 'posix' module = 'nt' if IS_WINDOWS else 'posix'
error_msg = f"unsupported GLOBAL {module}.execv whose module {module} is blocked" error_msg = f"unsupported GLOBAL {module}.execv whose module {module} is blocked"

View file

@ -292,6 +292,13 @@ class Unpickler:
elif type(inst) in _get_user_allowed_globals().values(): elif type(inst) in _get_user_allowed_globals().values():
if hasattr(inst, "__setstate__"): if hasattr(inst, "__setstate__"):
inst.__setstate__(state) inst.__setstate__(state)
elif hasattr(inst, "__slots__"):
# if slots are defined, state will be a tuple (state, slotstate)
state, slotstate = state
for k, v in slotstate.items():
setattr(inst, k, v)
if state:
inst.__dict__.update(state)
else: else:
inst.__dict__.update(state) inst.__dict__.update(state)
else: else: