Add test for pickle_module (#98373)

I.e. a regression test for https://github.com/pytorch/pytorch/issues/88438

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98373
Approved by: https://github.com/huydhn, https://github.com/kit1980
This commit is contained in:
Nikita Shulga 2023-04-05 13:05:05 +00:00 committed by PyTorch MergeBot
parent ea00f850e9
commit 3da7e83250

View file

@ -251,6 +251,25 @@ class SerializationMixin:
with self.assertRaisesRegex(ValueError, 'supports dill >='):
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
def test_pickle_module(self):
class ThrowingUnpickler(pickle.Unpickler):
def load(self, *args, **kwargs):
raise RuntimeError("rumpelstiltskin")
class ThrowingModule:
Unpickler = ThrowingUnpickler
load = ThrowingUnpickler.load
x = torch.eye(3)
with tempfile.NamedTemporaryFile() as f:
torch.save(x, f)
f.seek(0)
with self.assertRaisesRegex(RuntimeError, "rumpelstiltskin"):
torch.load(f, pickle_module=ThrowingModule)
f.seek(0)
z = torch.load(f)
self.assertEqual(x, z)
@unittest.skipIf(
not TEST_DILL or not HAS_DILL_AT_LEAST_0_3_1,
'"dill" not found or not correct version'