mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add test for pickle_module (#98373)
I.e. a regression test for https://github.com/pytorch/pytorch/issues/88438 Pull Request resolved: https://github.com/pytorch/pytorch/pull/98373 Approved by: https://github.com/huydhn, https://github.com/kit1980
This commit is contained in:
parent
ea00f850e9
commit
3da7e83250
1 changed files with 19 additions and 0 deletions
|
|
@ -251,6 +251,25 @@ class SerializationMixin:
|
|||
with self.assertRaisesRegex(ValueError, 'supports dill >='):
|
||||
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
|
||||
|
||||
def test_pickle_module(self):
|
||||
class ThrowingUnpickler(pickle.Unpickler):
|
||||
def load(self, *args, **kwargs):
|
||||
raise RuntimeError("rumpelstiltskin")
|
||||
|
||||
class ThrowingModule:
|
||||
Unpickler = ThrowingUnpickler
|
||||
load = ThrowingUnpickler.load
|
||||
|
||||
x = torch.eye(3)
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(x, f)
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "rumpelstiltskin"):
|
||||
torch.load(f, pickle_module=ThrowingModule)
|
||||
f.seek(0)
|
||||
z = torch.load(f)
|
||||
self.assertEqual(x, z)
|
||||
|
||||
@unittest.skipIf(
|
||||
not TEST_DILL or not HAS_DILL_AT_LEAST_0_3_1,
|
||||
'"dill" not found or not correct version'
|
||||
|
|
|
|||
Loading…
Reference in a new issue