diff --git a/test/test_serialization.py b/test/test_serialization.py index 59d6e21bd3a..4e1a1a8e49d 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -16,6 +16,7 @@ import warnings import zipfile from collections import namedtuple, OrderedDict from copy import deepcopy +from dataclasses import dataclass from itertools import product from pathlib import Path @@ -844,6 +845,17 @@ class ClassThatUsesBuildInstruction: # Third item, state here will cause pickle to push a BUILD instruction return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'} +@dataclass +class ClassThatUsesBuildInstructionAllSlots: + __slots__ = ["x", "y"] + x: int + y: int + +@dataclass +class ClassThatUsesBuildInstructionSomeSlots(ClassThatUsesBuildInstructionAllSlots): + x: int + y: int + c: str @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") class TestBothSerialization(TestCase): @@ -1142,6 +1154,25 @@ class TestSerialization(TestCase, SerializationMixin): torch.serialization.clear_safe_globals() ClassThatUsesBuildInstruction.__setstate__ = None + @parametrize("slots", ['some', 'all']) + def test_weights_only_safe_globals_build_with_slots(self, slots): + obj_cls = ( + ClassThatUsesBuildInstructionAllSlots if slots == 'all' else ClassThatUsesBuildInstructionSomeSlots + ) + args = (2, 3) if slots == 'all' else (2, 3, 'foo') + obj = obj_cls(*args) + with BytesIOContext() as f: + torch.save(obj, f) + f.seek(0) + with self.assertRaisesRegex(pickle.UnpicklingError, + f"GLOBAL __main__.{obj_cls.__name__} was not an allowed global by default"): + torch.load(f, weights_only=True) + + f.seek(0) + with torch.serialization.safe_globals([obj_cls]): + loaded_obj = torch.load(f, weights_only=True) + self.assertEqual(loaded_obj, obj) + def test_weights_only_safe_globals_blocklist(self): module = 'nt' if IS_WINDOWS else 'posix' error_msg = f"unsupported GLOBAL {module}.execv whose module {module} is blocked" diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 918db8ba0be..eb6072971ec 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -292,6 +292,13 @@ class Unpickler: elif type(inst) in _get_user_allowed_globals().values(): if hasattr(inst, "__setstate__"): inst.__setstate__(state) + elif hasattr(inst, "__slots__"): + # if slots are defined, state will be a tuple (state, slotstate) + state, slotstate = state + for k, v in slotstate.items(): + setattr(inst, k, v) + if state: + inst.__dict__.update(state) else: inst.__dict__.update(state) else: