From 3da7e8325010a67ebf8ae1d6c21f15ada2d30d13 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 5 Apr 2023 13:05:05 +0000 Subject: [PATCH] Add test for pickle_module (#98373) I.e. a regression test for https://github.com/pytorch/pytorch/issues/88438 Pull Request resolved: https://github.com/pytorch/pytorch/pull/98373 Approved by: https://github.com/huydhn, https://github.com/kit1980 --- test/test_serialization.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/test_serialization.py b/test/test_serialization.py index 9b9a71334ba..a640eb16cc1 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -251,6 +251,25 @@ class SerializationMixin: with self.assertRaisesRegex(ValueError, 'supports dill >='): x2 = torch.load(f, pickle_module=dill, encoding='utf-8') + def test_pickle_module(self): + class ThrowingUnpickler(pickle.Unpickler): + def load(self, *args, **kwargs): + raise RuntimeError("rumpelstiltskin") + + class ThrowingModule: + Unpickler = ThrowingUnpickler + load = ThrowingUnpickler.load + + x = torch.eye(3) + with tempfile.NamedTemporaryFile() as f: + torch.save(x, f) + f.seek(0) + with self.assertRaisesRegex(RuntimeError, "rumpelstiltskin"): + torch.load(f, pickle_module=ThrowingModule) + f.seek(0) + z = torch.load(f) + self.assertEqual(x, z) + @unittest.skipIf( not TEST_DILL or not HAS_DILL_AT_LEAST_0_3_1, '"dill" not found or not correct version'