diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py
index 17a6aebd8f9..e2758e3216d 100644
--- a/test/distributed/_tensor/test_dtensor.py
+++ b/test/distributed/_tensor/test_dtensor.py
@@ -536,6 +536,16 @@ class DTensorTest(DTensorTestBase):
         buffer.seek(0)
         reloaded_st = torch.load(buffer)
         self.assertEqual(sharded_tensor, reloaded_st)
+        # Test weights_only load
+        try:
+            torch.serialization.add_safe_globals(
+                [DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
+            )
+            buffer.seek(0)
+            reloaded_st = torch.load(buffer, weights_only=True)
+            self.assertEqual(sharded_tensor, reloaded_st)
+        finally:
+            torch.serialization.clear_safe_globals()
 
 
 class DTensorMeshTest(DTensorTestBase):
diff --git a/test/test_serialization.py b/test/test_serialization.py
index 9890d9b2d9f..a25f9856271 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -15,7 +15,7 @@ import pickle
 import shutil
 import pathlib
 import platform
-from collections import OrderedDict
+from collections import namedtuple, OrderedDict
 from copy import deepcopy
 from itertools import product
 
@@ -804,6 +804,17 @@ class serialization_method:
     def __exit__(self, *args, **kwargs):
         torch.save = self.torch_save
 
+Point = namedtuple('Point', ['x', 'y'])
+
+class ClassThatUsesBuildInstruction:
+    def __init__(self, num):
+        self.num = num
+
+    def __reduce_ex__(self, proto):
+        # Third item, state here will cause pickle to push a BUILD instruction
+        return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}
+
+
 @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
 class TestBothSerialization(TestCase):
     @parametrize("weights_only", (True, False))
@@ -1049,6 +1060,58 @@ class TestSerialization(TestCase, SerializationMixin):
             finally:
                 torch.serialization.clear_safe_globals()
 
+    def test_weights_only_safe_globals_newobj(self):
+        # This will use NEWOBJ
+        p = Point(x=1, y=2)
+        with BytesIOContext() as f:
+            torch.save(p, f)
+            f.seek(0)
+            with self.assertRaisesRegex(pickle.UnpicklingError,
+                                        "GLOBAL __main__.Point was not an allowed global by default"):
+                torch.load(f, weights_only=True)
+            f.seek(0)
+            try:
+                torch.serialization.add_safe_globals([Point])
+                loaded_p = torch.load(f, weights_only=True)
+                self.assertEqual(loaded_p, p)
+            finally:
+                torch.serialization.clear_safe_globals()
+
+    def test_weights_only_safe_globals_build(self):
+        counter = 0
+
+        def fake_set_state(obj, *args):
+            nonlocal counter
+            counter += 1
+
+        c = ClassThatUsesBuildInstruction(2)
+        with BytesIOContext() as f:
+            torch.save(c, f)
+            f.seek(0)
+            with self.assertRaisesRegex(pickle.UnpicklingError,
+                                        "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
+                torch.load(f, weights_only=True)
+            try:
+                torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
+                # Test dict update path
+                f.seek(0)
+                loaded_c = torch.load(f, weights_only=True)
+                self.assertEqual(loaded_c.num, 2)
+                self.assertEqual(loaded_c.foo, 'bar')
+                # Test setstate path
+                ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
+                f.seek(0)
+                loaded_c = torch.load(f, weights_only=True)
+                self.assertEqual(loaded_c.num, 2)
+                self.assertEqual(counter, 1)
+                self.assertFalse(hasattr(loaded_c, 'foo'))
+            finally:
+                torch.serialization.clear_safe_globals()
+                # Remove the patched hook instead of assigning None: a None attribute
+                # still satisfies hasattr(), so a later BUILD would try to call it.
+                if "__setstate__" in vars(ClassThatUsesBuildInstruction):
+                    del ClassThatUsesBuildInstruction.__setstate__
+
     @parametrize('weights_only', (False, True))
     def test_serialization_math_bits(self, weights_only):
         t = torch.randn(1, dtype=torch.cfloat)
diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py
index cf1514dcf41..ba131d20478 100644
--- a/torch/_weights_only_unpickler.py
+++ b/torch/_weights_only_unpickler.py
@@ -231,9 +231,12 @@ class Unpickler:
             elif key[0] == NEWOBJ[0]:
                 args = self.stack.pop()
                 cls = self.stack.pop()
-                if cls is not torch.nn.Parameter:
+                if cls is torch.nn.Parameter:
+                    self.append(torch.nn.Parameter(*args))
+                elif cls in _get_user_allowed_globals().values():
+                    self.append(cls.__new__(cls, *args))
+                else:
                     raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
-                self.append(torch.nn.Parameter(*args))
             elif key[0] == REDUCE[0]:
                 args = self.stack.pop()
                 func = self.stack[-1]
@@ -255,9 +258,14 @@ class Unpickler:
                     inst.__setstate__(state)
                 elif type(inst) is OrderedDict:
                     inst.__dict__.update(state)
+                elif type(inst) in _get_user_allowed_globals().values():
+                    if hasattr(inst, "__setstate__"):
+                        inst.__setstate__(state)
+                    else:
+                        inst.__dict__.update(state)
                 else:
                     raise RuntimeError(
-                        f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
+                        f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
                     )
             # Stack manipulation
             elif key[0] == APPEND[0]:
diff --git a/torch/serialization.py b/torch/serialization.py
index 95d8d2e5cc6..d7c3fd15933 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -203,7 +203,9 @@ def get_safe_globals() -> List[Any]:
 
 def add_safe_globals(safe_globals: List[Any]) -> None:
     """
-    Marks the given globals as safe for ``weights_only`` load.
+    Marks the given globals as safe for ``weights_only`` load. For example, functions
+    added to this list can be called during unpickling, classes could be instantiated
+    and have state set.
 
     Args:
         safe_globals (List[Any]): list of globals to mark as safe