mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Allow BUILD/NEWOBJ instruction for items added via torch.serialization.add_safe_globals (#129251)
Previously, allowlisting functions/classes via `torch.serialization.add_safe_globals(obj)` for the `weights_only` Unpickler had the following effect:
- For a [`GLOBAL`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1926-L1939) instruction, `GLOBAL obj.__module__ obj.__name__` would be allowed and translated back to obj to be pushed back to the stack.
- For a [`REDUCE`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1926-L1982) instruction where we expect the stack to contain `func` and `args`, `func` is allowed if it was added via `add_safe_globals`
However, it did not have an effect on `BUILD` and `NEWOBJ` instructions.
Some classes may be rebuilt via [`NEWOBJ`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L2091-L2104) instruction, which indicates that their constructor should be used to rebuild the class.
Further, a [`BUILD`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1984-L2007) instruction might be used if an object's `__reduce__`/`__reduce_ex__` returns a non-None value for `state`, which indicates that the state should be applied via `__setstate__` or `__dict__.update`.
**This PR makes sure that adding objects to the allowlist will also allow `NEWOBJ` and `BUILD` instructions for them.**
In particular, the update for `NEWOBJ` should unblock allowlisting of `ScaledMMConfig` (defined in `float8_experimental/float8_tensor.py`, lines 26-30, at commit d4ade877df) in float8_experimental (cc @drisspg).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129251
Approved by: https://github.com/albanD
ghstack dependencies: #129244
This commit is contained in:
parent
1bb1e3463c
commit
c5f7755e86
4 changed files with 85 additions and 5 deletions
|
|
@ -536,6 +536,16 @@ class DTensorTest(DTensorTestBase):
|
|||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
# Test weights_only load
|
||||
try:
|
||||
torch.serialization.add_safe_globals(
|
||||
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
|
||||
)
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer, weights_only=True)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
finally:
|
||||
torch.serialization.clear_safe_globals()
|
||||
|
||||
|
||||
class DTensorMeshTest(DTensorTestBase):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import pickle
|
|||
import shutil
|
||||
import pathlib
|
||||
import platform
|
||||
from collections import OrderedDict
|
||||
from collections import namedtuple, OrderedDict
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
|
||||
|
|
@ -804,6 +804,17 @@ class serialization_method:
|
|||
def __exit__(self, *args, **kwargs):
|
||||
torch.save = self.torch_save
|
||||
|
||||
Point = namedtuple('Point', ['x', 'y'])
|
||||
|
||||
class ClassThatUsesBuildInstruction:
    """Fixture whose pickled form contains a ``BUILD`` instruction.

    ``__reduce_ex__`` returns a three-item tuple; the third item (the state)
    is what makes pickle emit ``BUILD`` after the object is reconstructed.
    """

    def __init__(self, num):
        self.num = num

    def __reduce_ex__(self, proto):
        """Return ``(callable, args, state)`` for pickling.

        Args:
            proto: pickle protocol number (unused).

        Returns:
            A tuple of the class itself, the constructor arguments
            ``(self.num,)`` and a fixed state dict ``{'foo': 'bar'}``.
        """
        # Third item, state here will cause pickle to push a BUILD instruction
        return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}
|
||||
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
|
||||
class TestBothSerialization(TestCase):
|
||||
@parametrize("weights_only", (True, False))
|
||||
|
|
@ -1049,6 +1060,55 @@ class TestSerialization(TestCase, SerializationMixin):
|
|||
finally:
|
||||
torch.serialization.clear_safe_globals()
|
||||
|
||||
def test_weights_only_safe_globals_newobj(self):
    """Check that allowlisting a class permits rebuilding it via NEWOBJ.

    Loading with ``weights_only=True`` must first be rejected, and must
    succeed once ``Point`` has been added through ``add_safe_globals``.
    """
    # This will use NEWOBJ
    point = Point(x=1, y=2)
    blocked_msg = "GLOBAL __main__.Point was not an allowed global by default"
    with BytesIOContext() as stream:
        torch.save(point, stream)

        # Not allowlisted yet: the load has to be refused.
        stream.seek(0)
        with self.assertRaisesRegex(pickle.UnpicklingError, blocked_msg):
            torch.load(stream, weights_only=True)

        stream.seek(0)
        try:
            torch.serialization.add_safe_globals([Point])
            restored = torch.load(stream, weights_only=True)
            self.assertEqual(restored, point)
        finally:
            # Always reset the allowlist so other tests are unaffected.
            torch.serialization.clear_safe_globals()
|
||||
|
||||
def test_weights_only_safe_globals_build(self):
    """Check that allowlisting a class permits the BUILD instruction for it.

    Covers both ways state can be applied: ``__dict__.update`` when the
    class has no ``__setstate__``, and ``__setstate__`` when one is present.
    """
    counter = 0

    def fake_set_state(obj, *args):
        # Only counts invocations; deliberately does not apply the state.
        nonlocal counter
        counter += 1

    c = ClassThatUsesBuildInstruction(2)
    with BytesIOContext() as f:
        torch.save(c, f)
        f.seek(0)
        with self.assertRaisesRegex(pickle.UnpicklingError,
                                    "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
            torch.load(f, weights_only=True)
        try:
            torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
            # Test dict update path
            f.seek(0)
            loaded_c = torch.load(f, weights_only=True)
            self.assertEqual(loaded_c.num, 2)
            self.assertEqual(loaded_c.foo, 'bar')
            # Test setstate path
            ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
            f.seek(0)
            loaded_c = torch.load(f, weights_only=True)
            self.assertEqual(loaded_c.num, 2)
            self.assertEqual(counter, 1)
            self.assertFalse(hasattr(loaded_c, 'foo'))
        finally:
            torch.serialization.clear_safe_globals()
            # Remove the patched hook rather than assigning None: a None
            # attribute still satisfies hasattr(inst, "__setstate__"), so a
            # later load of this class would try to call None.
            if "__setstate__" in vars(ClassThatUsesBuildInstruction):
                del ClassThatUsesBuildInstruction.__setstate__
|
||||
|
||||
@parametrize('weights_only', (False, True))
|
||||
def test_serialization_math_bits(self, weights_only):
|
||||
t = torch.randn(1, dtype=torch.cfloat)
|
||||
|
|
|
|||
|
|
@ -231,9 +231,12 @@ class Unpickler:
|
|||
elif key[0] == NEWOBJ[0]:
|
||||
args = self.stack.pop()
|
||||
cls = self.stack.pop()
|
||||
if cls is not torch.nn.Parameter:
|
||||
if cls is torch.nn.Parameter:
|
||||
self.append(torch.nn.Parameter(*args))
|
||||
elif cls in _get_user_allowed_globals().values():
|
||||
self.append(cls.__new__(cls, *args))
|
||||
else:
|
||||
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
|
||||
self.append(torch.nn.Parameter(*args))
|
||||
elif key[0] == REDUCE[0]:
|
||||
args = self.stack.pop()
|
||||
func = self.stack[-1]
|
||||
|
|
@ -255,9 +258,14 @@ class Unpickler:
|
|||
inst.__setstate__(state)
|
||||
elif type(inst) is OrderedDict:
|
||||
inst.__dict__.update(state)
|
||||
elif type(inst) in _get_user_allowed_globals().values():
|
||||
if hasattr(inst, "__setstate__"):
|
||||
inst.__setstate__(state)
|
||||
else:
|
||||
inst.__dict__.update(state)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
|
||||
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
|
||||
)
|
||||
# Stack manipulation
|
||||
elif key[0] == APPEND[0]:
|
||||
|
|
|
|||
|
|
@ -203,7 +203,9 @@ def get_safe_globals() -> List[Any]:
|
|||
|
||||
def add_safe_globals(safe_globals: List[Any]) -> None:
|
||||
"""
|
||||
Marks the given globals as safe for ``weights_only`` load.
|
||||
Marks the given globals as safe for ``weights_only`` load. For example, functions
|
||||
added to this list can be called during unpickling, classes could be instantiated
|
||||
and have state set.
|
||||
|
||||
Args:
|
||||
safe_globals (List[Any]): list of globals to mark as safe
|
||||
|
|
|
|||
Loading…
Reference in a new issue