Allow BUILD/NEWOBJ instruction for items added via torch.serialization.add_safe_globals (#129251)

Previously, allowlisting functions/classes via `torch.serialization.add_safe_globals(obj)` for the `weights_only` Unpickler had the following effect:

- For a [`GLOBAL`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1926-L1939) instruction, `GLOBAL obj.__module__ obj.__name__` would be allowed and translated back to obj to be pushed back to the stack.
- For a [`REDUCE`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1926-L1982) instruction where we expect the stack to contain `func` and `args`, `func` is allowed if it was added via `add_safe_globals`

However, it did not have an effect on `BUILD` and `NEWOBJ` instructions

Some classes may be rebuilt via [`NEWOBJ`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L2091-L2104) instruction, which indicates that their constructor should be used to rebuild the class.

Further, a [`BUILD`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1984-L2007) instruction might be used if an object's `__reduce__`/`__reduce_ex__` returns a non-None value for `state`. Which indicates a `__setstate__` or `__dict__.update`.

**This PR makes sure that adding objects to the allowlist will also allow `NEWOBJ` and `BUILD` instructions for them.**

In particular, the update for `NEWOBJ` should unblock allowlisting of [`ScaledMMConfig`](d4ade877df/float8_experimental/float8_tensor.py (L26-L30)) in float8_experimental @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129251
Approved by: https://github.com/albanD
ghstack dependencies: #129244
This commit is contained in:
Mikayla Gawarecki 2024-06-24 18:08:26 -07:00 committed by PyTorch MergeBot
parent 1bb1e3463c
commit c5f7755e86
4 changed files with 85 additions and 5 deletions

View file

@ -536,6 +536,16 @@ class DTensorTest(DTensorTestBase):
buffer.seek(0)
reloaded_st = torch.load(buffer)
self.assertEqual(sharded_tensor, reloaded_st)
# Test weights_only load
try:
torch.serialization.add_safe_globals(
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
)
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=True)
self.assertEqual(sharded_tensor, reloaded_st)
finally:
torch.serialization.clear_safe_globals()
class DTensorMeshTest(DTensorTestBase):

View file

@ -15,7 +15,7 @@ import pickle
import shutil
import pathlib
import platform
from collections import OrderedDict
from collections import namedtuple, OrderedDict
from copy import deepcopy
from itertools import product
@ -804,6 +804,17 @@ class serialization_method:
def __exit__(self, *args, **kwargs):
torch.save = self.torch_save
Point = namedtuple('Point', ['x', 'y'])
class ClassThatUsesBuildInstruction:
def __init__(self, num):
self.num = num
def __reduce_ex__(self, proto):
# Third item, state here will cause pickle to push a BUILD instruction
return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
class TestBothSerialization(TestCase):
@parametrize("weights_only", (True, False))
@ -1049,6 +1060,55 @@ class TestSerialization(TestCase, SerializationMixin):
finally:
torch.serialization.clear_safe_globals()
def test_weights_only_safe_globals_newobj(self):
# This will use NEWOBJ
p = Point(x=1, y=2)
with BytesIOContext() as f:
torch.save(p, f)
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError,
"GLOBAL __main__.Point was not an allowed global by default"):
torch.load(f, weights_only=True)
f.seek(0)
try:
torch.serialization.add_safe_globals([Point])
loaded_p = torch.load(f, weights_only=True)
self.assertEqual(loaded_p, p)
finally:
torch.serialization.clear_safe_globals()
def test_weights_only_safe_globals_build(self):
counter = 0
def fake_set_state(obj, *args):
nonlocal counter
counter += 1
c = ClassThatUsesBuildInstruction(2)
with BytesIOContext() as f:
torch.save(c, f)
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError,
"GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
torch.load(f, weights_only=True)
try:
torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
# Test dict update path
f.seek(0)
loaded_c = torch.load(f, weights_only=True)
self.assertEqual(loaded_c.num, 2)
self.assertEqual(loaded_c.foo, 'bar')
# Test setstate path
ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
f.seek(0)
loaded_c = torch.load(f, weights_only=True)
self.assertEqual(loaded_c.num, 2)
self.assertEqual(counter, 1)
self.assertFalse(hasattr(loaded_c, 'foo'))
finally:
torch.serialization.clear_safe_globals()
ClassThatUsesBuildInstruction.__setstate__ = None
@parametrize('weights_only', (False, True))
def test_serialization_math_bits(self, weights_only):
t = torch.randn(1, dtype=torch.cfloat)

View file

@ -231,9 +231,12 @@ class Unpickler:
elif key[0] == NEWOBJ[0]:
args = self.stack.pop()
cls = self.stack.pop()
if cls is not torch.nn.Parameter:
if cls is torch.nn.Parameter:
self.append(torch.nn.Parameter(*args))
elif cls in _get_user_allowed_globals().values():
self.append(cls.__new__(cls, *args))
else:
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
self.append(torch.nn.Parameter(*args))
elif key[0] == REDUCE[0]:
args = self.stack.pop()
func = self.stack[-1]
@ -255,9 +258,14 @@ class Unpickler:
inst.__setstate__(state)
elif type(inst) is OrderedDict:
inst.__dict__.update(state)
elif type(inst) in _get_user_allowed_globals().values():
if hasattr(inst, "__setstate__"):
inst.__setstate__(state)
else:
inst.__dict__.update(state)
else:
raise RuntimeError(
f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
)
# Stack manipulation
elif key[0] == APPEND[0]:

View file

@ -203,7 +203,9 @@ def get_safe_globals() -> List[Any]:
def add_safe_globals(safe_globals: List[Any]) -> None:
"""
Marks the given globals as safe for ``weights_only`` load.
Marks the given globals as safe for ``weights_only`` load. For example, functions
added to this list can be called during unpickling, classes could be instantiated
and have state set.
Args:
safe_globals (List[Any]): list of globals to mark as safe