pytorch/test/jit/test_save_load.py
Sam Estep e3900d2ba5 Add lint for unqualified noqa (#56272)
Summary:
As this diff shows, currently there are a couple hundred instances of raw `noqa` in the codebase, which just ignore all errors on a given line. That isn't great, so this PR changes all existing instances of that antipattern to qualify the `noqa` with respect to a specific error code, and adds a lint to prevent more of this from happening in the future.

Interestingly, some of the examples the `noqa` lint catches are genuine attempts to qualify the `noqa` with a specific error code, such as these two:
```
test/jit/test_misc.py:27:            print(f"{hello + ' ' + test}, I'm a {test}") # noqa E999
test/jit/test_misc.py:28:            print(f"format blank") # noqa F541
```
However, those are still wrong because they are [missing a colon](https://flake8.pycqa.org/en/3.9.1/user/violations.html#in-line-ignoring-errors), which actually causes the error code to be completely ignored:

- If you change them to anything else, the warnings will still be suppressed.
- If you add the necessary colons then it is revealed that `E261` was also being suppressed, unintentionally:
  ```
  test/jit/test_misc.py:27:57: E261 at least two spaces before inline comment
  test/jit/test_misc.py:28:35: E261 at least two spaces before inline comment
  ```

I did try using [flake8-noqa](https://pypi.org/project/flake8-noqa/) instead of a custom `git grep` lint, but it didn't seem to work. This PR is definitely missing some of the functionality that flake8-noqa is supposed to provide, though, so if someone can figure out how to use it, we should do that instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56272

Test Plan:
CI should pass on the tip of this PR, and we know that the lint works because the following CI run (before this PR was finished) failed:

- https://github.com/pytorch/pytorch/runs/2365189927

Reviewed By: janeyx99

Differential Revision: D27830127

Pulled By: samestep

fbshipit-source-id: d6dcf4f945ebd18cd76c46a07f3b408296864fcb
2021-04-19 13:16:18 -07:00

1022 lines
36 KiB
Python

from itertools import product as product
from typing import NamedTuple, Optional
import io
import os
import pathlib
import random
import sys
from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName
import torch
# Make the helper files in test/ importable: this file lives in test/jit/,
# so the directory two levels above it is test/.
_jit_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_jit_test_dir)
sys.path.append(pytorch_test_dir)
del _jit_test_dir
from torch.testing._internal.jit_utils import (JitTestCase,
clear_class_registry)
# This module is only meant to be collected through test/test_jit.py;
# executing it as a script fails fast with usage instructions.
if __name__ == "__main__":
    raise RuntimeError(
        "\n\n".join((
            "This test file is not meant to be run directly, use:",
            "\tpython test/test_jit.py TESTNAME",
            "instead.",
        ))
    )
class TestSaveLoad(JitTestCase):
    """TorchScript save/load round-trip tests.

    Covers versioned operators (loaded from fixtures under the directory
    ``pytorch_test_dir + "/jit/fixtures"``), qualified-name collisions across
    CompilationUnits, extra files, pathlib paths, NamedTuples, and
    parameter/buffer/submodule preservation. Most versioned-operator tests
    skip themselves when their fixture cannot be loaded.
    """

    def test_versioned_symbols(self):
        """
        Tests Torchscript symbol versioning. See note [Versioned Symbols].
        This test uses an undocumented, test-only function
        torch._test_serialization_subcmul.

        This function is implemented as (a - alpha * b) with a default value
        of 1 for alpha. In file format version 2, however, it was implemented
        as (b - alpha * a) with a default value of 2 for alpha.
        This test verifies a module serialized with file format version 2
        exhibits the old behavior, and that the same module newly serialized
        exhibits the current behavior.
        """
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, a, b, alpha: float):
                no_alpha = torch._test_serialization_subcmul(a, b)
                with_alpha = torch._test_serialization_subcmul(a, b, alpha)
                return no_alpha, with_alpha

        def historic_subcmul(a, b, alpha=2):
            return b - alpha * a

        def current_subcmul(a, b, alpha=1):
            return a - alpha * b

        # Loads and verifies the historic behavior of the module
        # that was serialized with version 2
        module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
        a = torch.randn((5,))
        b = torch.randn((5,))
        alpha = random.random()
        args = (a, b, alpha)
        no_alpha_v2, with_alpha_v2 = module_v2(*args)
        self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
        self.assertEqual(with_alpha_v2, historic_subcmul(*args))

        # Scripts, saves, loads and verifies the current behavior of the module
        module_current = self._save_load_module(MyModule)
        no_alpha_current, with_alpha_current = module_current(*args)
        self.assertEqual(no_alpha_current, current_subcmul(a, b))
        self.assertEqual(with_alpha_current, current_subcmul(*args))

    def _save_load_module(self, m):
        """Script ``m()``, save it to an in-memory buffer and load it back.

        ``m`` is a module class (not an instance); it is instantiated with no
        arguments before scripting.
        """
        scripted_module = torch.jit.script(m())
        buffer = io.BytesIO()
        torch.jit.save(scripted_module, buffer)
        buffer.seek(0)
        return torch.jit.load(buffer)

    def _try_fn(self, fn, *args, **kwargs):
        """Return ``fn(*args, **kwargs)``, or the exception it raised.

        Lets the tests check "both sides raised" vs. "both sides returned the
        same value" without caring about the exact exception type.
        """
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            return e

    def _verify_no(self, kind, m):
        """Assert that ``kind`` does not occur in any node of ``m.graph``."""
        self._verify_count(kind, m, 0)

    def _verify_count(self, kind, m, count):
        """Assert that ``kind`` occurs ``count`` times across ``m.graph``'s nodes.

        Occurrences are counted by substring search over each node's string
        form, so e.g. ``"aten::div"`` also matches ``"aten::div_"``.
        """
        node_count = sum(str(n).count(kind) for n in m.graph.nodes())
        self.assertEqual(node_count, count)

    # Tests that verify Torchscript remaps aten::div(_) from versions 0-3
    # to call either aten::true_divide(_), if an input is a float type,
    # or truncated aten::divide(_) otherwise.
    # NOTE: currently compares against current div behavior, too, since
    # div behavior has not yet been updated.

    def test_versioned_div_tensor(self):
        """Tensor / Tensor: v3 fixture vs. historic_div, re-scripted module vs. torch.div."""
        def historic_div(self, other):
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide(other)
            return self.divide(other, rounding_mode='trunc')

        # Tensor x Tensor
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, a, b):
                result_0 = a / b
                result_1 = torch.div(a, b)
                result_2 = a.div(b)

                return result_0, result_1, result_2

        # Loads historic module
        try:
            v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::div", v3_module, 6)  # true_divide and divide alias to div
        self._verify_count('prim::Constant[value="trunc"]', v3_module, 1)  # rounding_mode argument

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::div", current_module, 3)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            def _helper(m, fn):
                m_results = self._try_fn(m, a, b)
                fn_result = self._try_fn(fn, a, b)

                if isinstance(m_results, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                else:
                    for result in m_results:
                        self.assertEqual(result, fn_result)

            _helper(v3_module, historic_div)
            _helper(current_module, torch.div)

    def test_versioned_div_tensor_inplace(self):
        """In-place Tensor /= Tensor: v3 fixture vs. historic_div_, current vs. Tensor.div_."""
        def historic_div_(self, other):
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide_(other)
            return self.divide_(other, rounding_mode='trunc')

        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, a, b):
                a /= b
                return a

        try:
            v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::div", v3_module, 2)  # true_divide and divide both alias to div
        self._verify_count('prim::Constant[value="trunc"]', v3_module, 1)  # rounding_mode argument

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::div", current_module, 1)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            def _helper(m, fn):
                # fn works on a copy so its result is independent of m's
                # in-place update of a
                fn_result = self._try_fn(fn, a.clone(), b)
                m_result = self._try_fn(m, a, b)

                if isinstance(m_result, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                else:
                    self.assertEqual(m_result, fn_result)
                    self.assertEqual(m_result, a)

            _helper(v3_module, historic_div_)

            # Recreates a since it was modified in place
            a = torch.tensor((val_a,))
            _helper(current_module, torch.Tensor.div_)

    def test_versioned_div_tensor_out(self):
        """div with an explicit ``out=`` tensor, for both float and long ``out``."""
        def historic_div_out(self, other, out):
            if self.is_floating_point() or other.is_floating_point() or out.is_floating_point():
                return torch.true_divide(self, other, out=out)
            return torch.divide(self, other, out=out, rounding_mode='trunc')

        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, a, b, out):
                return a.div(b, out=out)

        try:
            v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::div", v3_module, 2)  # true_divide and divide alias to div
        self._verify_count('prim::Constant[value="trunc"]', v3_module, 1)  # rounding_mode argument

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::div", current_module, 1)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
                def _helper(m, fn):
                    # torch.div is called with out= as a keyword, while
                    # historic_div_out takes out positionally
                    if fn is torch.div:
                        fn_result = self._try_fn(fn, a, b, out=out.clone())
                    else:
                        fn_result = self._try_fn(fn, a, b, out.clone())
                    m_result = self._try_fn(m, a, b, out)

                    if isinstance(m_result, Exception):
                        self.assertTrue(isinstance(fn_result, Exception))
                    else:
                        self.assertEqual(m_result, fn_result)
                        self.assertEqual(m_result, out)

                _helper(v3_module, historic_div_out)
                _helper(current_module, torch.div)

    def test_versioned_div_scalar(self):
        """Tensor / scalar, for float and int scalars."""
        def historic_div_scalar_float(self, other: float):
            return torch.true_divide(self, other)

        def historic_div_scalar_int(self, other: int):
            if self.is_floating_point():
                return torch.true_divide(self, other)
            return torch.divide(self, other, rounding_mode='trunc')

        class MyModuleFloat(torch.nn.Module):
            def __init__(self):
                super(MyModuleFloat, self).__init__()

            def forward(self, a, b: float):
                return a / b

        class MyModuleInt(torch.nn.Module):
            def __init__(self):
                super(MyModuleInt, self).__init__()

            def forward(self, a, b: int):
                return a / b

        try:
            v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
            v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_int_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        for m in (v3_module_float, v3_module_int):
            self._verify_count("aten::div", m, 2)  # true_divide and divide alias to div
            self._verify_count('prim::Constant[value="trunc"]', m, 1)  # rounding_mode argument

        current_module_float = self._save_load_module(MyModuleFloat)
        current_module_int = self._save_load_module(MyModuleInt)

        for m in (current_module_float, current_module_int):
            self._verify_count("aten::div", m, 1)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = val_b

            def _helper(m, fn):
                m_result = self._try_fn(m, a, b)
                fn_result = self._try_fn(fn, a, b)

                if isinstance(m_result, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                else:
                    self.assertEqual(m_result, fn_result)

            if isinstance(b, float):
                _helper(v3_module_float, historic_div_scalar_float)
                _helper(current_module_float, torch.div)
            else:
                _helper(v3_module_int, historic_div_scalar_int)
                _helper(current_module_int, torch.div)

    def test_versioned_div_scalar_reciprocal(self):
        """scalar / Tensor, for float and int scalars."""
        def historic_div_scalar_float_reciprocal(self, other: float):
            return other / self

        def historic_div_scalar_int_reciprocal(self, other: int):
            if self.is_floating_point():
                return other / self
            return torch.divide(other, self, rounding_mode='trunc')

        class MyModuleFloat(torch.nn.Module):
            def __init__(self):
                super(MyModuleFloat, self).__init__()

            def forward(self, a, b: float):
                return b / a

        class MyModuleInt(torch.nn.Module):
            def __init__(self):
                super(MyModuleInt, self).__init__()

            def forward(self, a, b: int):
                return b / a

        try:
            v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
            v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        # NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
        #  so true_divide and floor_divide do not appear in their graphs
        for m in (v3_module_float, v3_module_int):
            self._verify_no("aten::div", m)
            self._verify_no("aten::true_divide", m)
            self._verify_no("aten::floor_divide", m)
            self._verify_count("aten::reciprocal", m, 1)

        current_module_float = self._save_load_module(MyModuleFloat)
        current_module_int = self._save_load_module(MyModuleInt)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = val_b

            def _helper(m, fn):
                m_result = self._try_fn(m, a, b)
                # Reverses argument order for torch.div
                if fn is torch.div:
                    fn_result = self._try_fn(torch.div, b, a)
                else:
                    fn_result = self._try_fn(fn, a, b)

                if isinstance(m_result, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                elif fn is torch.div or a.is_floating_point():
                    self.assertEqual(m_result, fn_result)
                else:
                    # Skip when fn is not torch.div and a is integral because
                    # historic_div_scalar_int_reciprocal then divides with
                    # rounding_mode='trunc'
                    pass

            if isinstance(b, float):
                _helper(v3_module_float, historic_div_scalar_float_reciprocal)
                _helper(current_module_float, torch.div)
            else:
                _helper(v3_module_int, historic_div_scalar_int_reciprocal)
                _helper(current_module_int, torch.div)

    def test_versioned_div_scalar_inplace(self):
        """In-place Tensor /= scalar, for float and int scalars."""
        def historic_div_scalar_float_inplace(self, other: float):
            return self.true_divide_(other)

        def historic_div_scalar_int_inplace(self, other: int):
            if self.is_floating_point():
                return self.true_divide_(other)

            return self.divide_(other, rounding_mode='trunc')

        class MyModuleFloat(torch.nn.Module):
            def __init__(self):
                super(MyModuleFloat, self).__init__()

            def forward(self, a, b: float):
                a /= b
                return a

        class MyModuleInt(torch.nn.Module):
            def __init__(self):
                super(MyModuleInt, self).__init__()

            def forward(self, a, b: int):
                a /= b
                return a

        try:
            v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
            v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        for m in (v3_module_float, v3_module_int):
            self._verify_count("aten::div_", m, 2)  # true_divide and divide alias to div
            self._verify_count('prim::Constant[value="trunc"]', m, 1)  # rounding_mode argument

        current_module_float = self._save_load_module(MyModuleFloat)
        current_module_int = self._save_load_module(MyModuleInt)

        for m in (current_module_float, current_module_int):
            self._verify_count("aten::div", m, 1)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = val_b

            def _helper(m, fn):
                # Both m and fn update their input in place and return it, so
                # fn must work on a copy; otherwise both results would be the
                # very same tensor and the comparison below would be vacuous.
                fn_result = self._try_fn(fn, a.clone(), b)
                m_result = self._try_fn(m, a, b)

                if isinstance(m_result, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                else:
                    self.assertEqual(m_result, fn_result)

            if isinstance(b, float):
                _helper(v3_module_float, historic_div_scalar_float_inplace)
                # Recreates a since it was modified in place
                a = torch.tensor((val_a,))
                _helper(current_module_float, torch.Tensor.div_)
            else:
                _helper(v3_module_int, historic_div_scalar_int_inplace)
                # Recreates a since it was modified in place
                a = torch.tensor((val_a,))
                _helper(current_module_int, torch.Tensor.div_)

    # NOTE: Scalar division was already true division in op version 3,
    #   so this test verifies the behavior is unchanged.
    def test_versioned_div_scalar_scalar(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, a: float, b: int, c: float, d: int):
                result_0 = a / b
                result_1 = a / c
                result_2 = b / c
                result_3 = b / d
                return (result_0, result_1, result_2, result_3)

        try:
            v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::div", v3_module, 4)

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::div", current_module, 4)

        def _helper(m, fn):
            vals = (5., 3, 2., 7)
            m_result = m(*vals)
            fn_result = fn(*vals)
            for mr, hr in zip(m_result, fn_result):
                self.assertEqual(mr, hr)

        _helper(v3_module, current_module)

    # NOTE: the JIT was incapable of handling boolean fill values when
    #   PyTorch produced file format versions 0-4
    def test_versioned_full_integer_value(self):
        """An integer fill value in the v4 fixture yields floating-point results."""
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, int_fill: int):
                size = torch.Size(2, 2)
                a = torch.full(size, int_fill)
                b = torch.full(size, 1)
                return (a, b)

        try:
            v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::full", v4_module, 2)

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::full", current_module, 2)

        # Verifies historic integer type inference is float
        # NOTE: only verifies floating point, not exact dtype, due to
        #   https://github.com/pytorch/pytorch/issues/40470
        results = v4_module(2)
        for result in results:
            self.assertTrue(result.is_floating_point())

        # Verifies values are correct
        a, b = results
        self.assertTrue((a == 2.).all())
        self.assertTrue((b == 1.).all())

    # Tests that torch.full behavior which is the same from prior versions
    #   to version 5 is preserved.
    # NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
    #   it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
    def test_versioned_full_preserved(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, float_fill: float):
                size = (2, 2)
                a = torch.full(size, 1.)
                b = torch.full(size, float_fill)
                c = torch.full(size, float_fill, dtype=torch.long)

                out = torch.empty(size, dtype=torch.long)
                d = torch.full(size, float_fill, out=out)

                e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
                               layout=torch.strided, device='cpu')
                return (a, b, c, d, e)

        try:
            v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::full", v4_module, 5)

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::full", current_module, 5)

        self.assertEqual(v4_module(2.), current_module(2.))

    def test_versioned_symbols_reserialization(self):
        """
        Tests that loading and saving serialized Torchscript with a versioned
        symbol won't persist the original function and will inline the
        versioned builtin.
        """
        module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
        buffer = io.BytesIO()
        torch.jit.save(module_v2, buffer)
        buffer.seek(0)
        module_reserialized = torch.jit.load(buffer)

        subcmul_nodes = sum("subcmul" in n.kind() for
                            n in module_reserialized.graph.nodes())
        self.assertEqual(subcmul_nodes, 0)

    def test_different_modules(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()
                self.foo = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                return x

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        # Save the module whose qualified name is checked below instead of
        # scripting Foo a second time
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_functions(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """
        def lol(x):
            return x

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        def lol(x):  # noqa: F811
            return "hello"

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        # Save the module whose qualified name is checked below instead of
        # scripting Foo a second time
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_interfaces(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """
        @torch.jit.interface
        class MyInterface(object):
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface(object):
            def __init__(self):
                pass

            def bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self):
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.bar(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface(object):
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface(object):  # noqa: F811
            def __init__(self):
                pass

            def not_bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self):
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.not_bar(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        # Save the module whose qualified name is checked below instead of
        # scripting Foo a second time
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_many_collisions(self):
        """
        Like the test_different_* tests, but a module, a function, an
        interface, a script class and a NamedTuple all share qualified names
        across the two CompilationUnits at once.
        """
        class MyCoolNamedTuple(NamedTuple):
            a: int

        @torch.jit.interface
        class MyInterface(object):
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface(object):
            def __init__(self):
                pass

            def bar(self, x):
                return x

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                x = lol(x)
                x = self.interface.bar(x)

                return x, MyCoolNamedTuple(a=5)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface(object):
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface(object):  # noqa: F811
            def __init__(self):
                pass

            def not_bar(self, x):
                return x

        def lol(x):  # noqa: F811
            return "asdofij"

        class MyCoolNamedTuple(NamedTuple):  # noqa: F811
            a: str

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                self.interface.not_bar(x)
                x = lol(x)
                return x, MyCoolNamedTuple(a="hello")

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x, named_tuple_1 = self.first(x)
                x, named_tuple_2 = self.second(x)
                return len(x + named_tuple_2.a) + named_tuple_1.a

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_save_load_with_extra_files(self):
        """Round-trips ``_extra_files`` through file- and buffer-based save/load."""
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return a

        # specifically test binary data
        value = b"bar\x00\xffbaz"

        expected_extra_files = {}
        expected_extra_files['foo'] = value
        # verify that str to bytes conversion also works
        expected_extra_files['foo2'] = "bar"
        m = MyMod()

        # Save to file.
        with TemporaryFileName() as fname:
            m.save(fname, _extra_files=expected_extra_files)
            # values don't matter
            extra_files = {'foo': '', 'foo2': None}
            torch.jit.load(fname, _extra_files=extra_files)
            self.assertEqual(value, extra_files['foo'])
            # results come back always as bytes
            self.assertEqual(b"bar", extra_files['foo2'])

            # Use torch.jit API
            torch.jit.save(m, fname, _extra_files=expected_extra_files)
            extra_files['foo'] = ''
            torch.jit.load(fname, _extra_files=extra_files)
            self.assertEqual(value, extra_files['foo'])

        # Save to buffer.
        buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
        extra_files = {'foo': ''}
        torch.jit.load(buffer, _extra_files=extra_files)
        self.assertEqual(value, extra_files['foo'])

        # Use torch.jit API
        buffer = io.BytesIO()
        torch.jit.save(m, buffer, _extra_files=expected_extra_files)
        buffer.seek(0)
        extra_files = {'foo': ''}
        torch.jit.load(buffer, _extra_files=extra_files)
        self.assertEqual(value, extra_files['foo'])

        # Non-existent file 'bar'
        # NOTE(review): ``buffer`` is not rewound before this load and may
        # already have been consumed by the load above — confirm that the
        # RuntimeError comes from the missing 'bar' entry and not from
        # reading an exhausted buffer.
        extra_files['bar'] = ''
        with self.assertRaises(RuntimeError):
            torch.jit.load(buffer, _extra_files=extra_files)

    def test_save_load_using_pathlib(self):
        """ScriptModule.save and torch.jit.load accept a ``pathlib.Path``."""
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return 2 * a

        m = MyMod()

        # Save then load.
        with TemporaryFileName() as fname:
            path = pathlib.Path(fname)
            m.save(path)
            m2 = torch.jit.load(path)

        x = torch.tensor([1., 2., 3., 4.])
        self.assertTrue(torch.equal(m(x), m2(x)))

    def test_save_nonexit_file(self):
        """Saving under a directory that does not exist raises RuntimeError."""
        class Foo(torch.nn.Module):
            def forward(self, x):
                return 2 * x

        script_module = torch.jit.script(Foo())
        with self.assertRaises(RuntimeError):
            script_module.save("NonExist/path/test.pt")

    def test_save_namedtuple_input_only(self):
        """
        Even if a NamedTuple is only used as an input argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self, x: FooTuple) -> torch.Tensor:
                return torch.tensor(3)

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded(FooTuple(a=5))
        self.assertEqual(output, torch.tensor(3))

    def test_save_namedtuple_output_only(self):
        """
        Even if a NamedTuple is only used as an output argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self) -> Optional[FooTuple]:
                return None

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded()
        self.assertEqual(output, None)

    def test_save_load_params_buffers_submodules(self):
        """
        Check that parameters, buffers, and submodules are the same after loading.
        """

        class Submodule(torch.nn.Module):
            def __init__(self):
                super().__init__()

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("submodule_a", Submodule())
                self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4)))
                self.register_buffer("buffer", torch.randn(4))
                self.t = torch.rand(4)  # not buffer

                self.parameter_b = torch.nn.Parameter(torch.randn(4))
                self.submodule_b = Submodule()

        m = TestModule()
        m_loaded = self.getExportImportCopy(torch.jit.script(m))

        # Check submodules.
        self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules())))
        for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
            m_name, _ = m_s
            loaded_name, _ = loaded_s
            self.assertEqual(m_name, loaded_name)

        # Check parameters.
        self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters())))
        for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()):
            self.assertEqual(m_p, loaded_p)

        # Check buffers.
        self.assertEqual(len(list(m.named_buffers())), len(list(m_loaded.named_buffers())))
        for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
            m_name, m_buffer = m_b
            loaded_name, loaded_buffer = loaded_b
            self.assertEqual(m_name, loaded_name)
            self.assertEqual(m_buffer, loaded_buffer)