mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary:
As this diff shows, currently there are a couple hundred instances of raw `noqa` in the codebase, which just ignore all errors on a given line. That isn't great, so this PR changes all existing instances of that antipattern to qualify the `noqa` with respect to a specific error code, and adds a lint to prevent more of this from happening in the future.
Interestingly, some of the examples the `noqa` lint catches are genuine attempts to qualify the `noqa` with a specific error code, such as these two:
```
test/jit/test_misc.py:27: print(f"{hello + ' ' + test}, I'm a {test}") # noqa E999
test/jit/test_misc.py:28: print(f"format blank") # noqa F541
```
However, those are still wrong because they are [missing a colon](https://flake8.pycqa.org/en/3.9.1/user/violations.html#in-line-ignoring-errors), which actually causes the error code to be completely ignored:
- If you change them to anything else, the warnings will still be suppressed.
- If you add the necessary colons then it is revealed that `E261` was also being suppressed, unintentionally:
```
test/jit/test_misc.py:27:57: E261 at least two spaces before inline comment
test/jit/test_misc.py:28:35: E261 at least two spaces before inline comment
```
I did try using [flake8-noqa](https://pypi.org/project/flake8-noqa/) instead of a custom `git grep` lint, but it didn't seem to work. This PR is definitely missing some of the functionality that flake8-noqa is supposed to provide, though, so if someone can figure out how to use it, we should do that instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56272
Test Plan:
CI should pass on the tip of this PR, and we know that the lint works because the following CI run (before this PR was finished) failed:
- https://github.com/pytorch/pytorch/runs/2365189927
Reviewed By: janeyx99
Differential Revision: D27830127
Pulled By: samestep
fbshipit-source-id: d6dcf4f945ebd18cd76c46a07f3b408296864fcb
1022 lines
36 KiB
Python
1022 lines
36 KiB
Python
from itertools import product as product
|
|
from typing import NamedTuple, Optional
|
|
import io
|
|
import os
|
|
import pathlib
|
|
import random
|
|
import sys
|
|
|
|
from torch import Tensor
|
|
from torch.testing._internal.common_utils import TemporaryFileName
|
|
import torch
|
|
|
|
# Make the helper files in test/ importable.
# The two dirname() calls strip this file's name and then its containing
# directory, so pytorch_test_dir is the parent of the directory holding this
# file. The fixture paths used by the tests below are built relative to it.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import (JitTestCase,
|
|
clear_class_registry)
|
|
|
|
# This module is only meant to be collected through test/test_jit.py, so
# running it as a script fails immediately with usage instructions.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
|
|
|
|
class TestSaveLoad(JitTestCase):
    """
    Tests for saving and loading TorchScript modules.

    Covers versioned operator fixtures (modules serialized by older file
    format versions), qualified-name collisions between separately saved
    modules, extra files, pathlib paths, NamedTuple signatures, and
    parameters/buffers/submodules surviving a round trip.
    """

    def test_versioned_symbols(self):
        """
        Tests Torchscript symbol versioning. See note [Versioned Symbols].
        This test uses an undocumented, test-only function
        torch._test_serialization_subcmul.

        This function is implemented as (a - alpha * b) with a default value
        of 1 for alpha. In file format version 2, however, it was implemented
        as (b - alpha * a) with a default value of 2 for alpha.
        This test verifies a module serialized with file format version 2
        exhibits the old behavior, and that the same module newly serialized
        exhibits the current behavior.
        """
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, alpha: float):
                no_alpha = torch._test_serialization_subcmul(a, b)
                with_alpha = torch._test_serialization_subcmul(a, b, alpha)
                return no_alpha, with_alpha

        def historic_subcmul(a, b, alpha=2):
            return b - alpha * a

        def current_subcmul(a, b, alpha=1):
            return a - alpha * b

        # Loads and verifies the historic behavior of the module
        # that was serialized with version 2
        module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
        a = torch.randn((5,))
        b = torch.randn((5,))
        alpha = random.random()
        args = (a, b, alpha)
        no_alpha_v2, with_alpha_v2 = module_v2(*args)
        self.assertEqual(no_alpha_v2, historic_subcmul(a, b))
        self.assertEqual(with_alpha_v2, historic_subcmul(*args))

        # Scripts, saves, loads and verifies the current behavior of the module
        scripted_module = torch.jit.script(MyModule())
        buffer = io.BytesIO()
        torch.jit.save(scripted_module, buffer)
        buffer.seek(0)
        module_current = torch.jit.load(buffer)
        no_alpha_current, with_alpha_current = module_current(*args)
        self.assertEqual(no_alpha_current, current_subcmul(a, b))
        self.assertEqual(with_alpha_current, current_subcmul(*args))

    def _save_load_module(self, m):
        """
        Helper that returns the module after saving and loading.

        `m` is a zero-argument callable (typically a module class); it is
        instantiated, scripted, saved to an in-memory buffer and loaded back.
        """
        scripted_module = torch.jit.script(m())
        buffer = io.BytesIO()
        torch.jit.save(scripted_module, buffer)
        buffer.seek(0)
        return torch.jit.load(buffer)

    def _try_fn(self, fn, *args, **kwargs):
        """
        Helper which returns the result of a function or the exception the
        function threw.

        The exception is returned, not raised, so callers can compare the
        failure behavior of two implementations.
        """
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            return e

    def _verify_no(self, kind, m):
        """Asserts that `kind` does not occur in any node of m's graph."""
        self._verify_count(kind, m, 0)

    def _verify_count(self, kind, m, count):
        """
        Asserts that the string `kind` occurs exactly `count` times in total
        across the string forms of the nodes of m.graph.

        This is a substring count, so e.g. "aten::div" also matches inside
        "aten::div_".
        """
        node_count = sum(str(n).count(kind) for n in m.graph.nodes())
        self.assertEqual(node_count, count)

    # Tests that verify Torchscript remaps aten::div(_) from versions 0-3
    # to call either aten::true_divide(_), if an input is a float type,
    # or truncated aten::divide(_) otherwise.
    #
    # NOTE: currently compares against current div behavior, too, since
    # div behavior has not yet been updated.

    def test_versioned_div_tensor(self):
        """Tensor / Tensor division: version 3 fixture vs. current behavior."""
        def historic_div(self, other):
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide(other)
            return self.divide(other, rounding_mode='trunc')

        # Tensor x Tensor
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b):
                result_0 = a / b
                result_1 = torch.div(a, b)
                result_2 = a.div(b)

                return result_0, result_1, result_2

        # Loads historic module
        try:
            v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::div", v3_module, 6)  # true_divide and divide alias to div
        self._verify_count('prim::Constant[value="trunc"]', v3_module, 1)  # rounding_mode argument

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::div", current_module, 3)

        def _helper(m, fn, a, b):
            # The module returns three results, each of which must match fn
            m_results = self._try_fn(m, a, b)
            fn_result = self._try_fn(fn, a, b)

            if isinstance(m_results, Exception):
                self.assertIsInstance(fn_result, Exception)
            else:
                for result in m_results:
                    self.assertEqual(result, fn_result)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            _helper(v3_module, historic_div, a, b)
            _helper(current_module, torch.div, a, b)

    def test_versioned_div_tensor_inplace(self):
        """In-place Tensor /= Tensor: version 3 fixture vs. current behavior."""
        def historic_div_(self, other):
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide_(other)
            return self.divide_(other, rounding_mode='trunc')

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b):
                a /= b
                return a

        try:
            v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::div", v3_module, 2)  # true_divide and divide both alias to div
        self._verify_count('prim::Constant[value="trunc"]', v3_module, 1)  # rounding_mode argument

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::div", current_module, 1)

        def _helper(m, fn, a, b):
            # The reference runs on a clone so that the module still sees
            # the unmodified input
            fn_result = self._try_fn(fn, a.clone(), b)
            m_result = self._try_fn(m, a, b)

            if isinstance(m_result, Exception):
                self.assertIsInstance(fn_result, Exception)
            else:
                self.assertEqual(m_result, fn_result)
                # The module must have updated its input in place
                self.assertEqual(m_result, a)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            b = torch.tensor((val_b,))

            _helper(v3_module, historic_div_, torch.tensor((val_a,)), b)

            # Recreates a since it was modified in place
            _helper(current_module, torch.Tensor.div_, torch.tensor((val_a,)), b)

    def test_versioned_div_tensor_out(self):
        """Tensor division with out=: version 3 fixture vs. current behavior."""
        def historic_div_out(self, other, out):
            if self.is_floating_point() or other.is_floating_point() or out.is_floating_point():
                return torch.true_divide(self, other, out=out)
            return torch.divide(self, other, out=out, rounding_mode='trunc')

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, out):
                return a.div(b, out=out)

        try:
            v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::div", v3_module, 2)  # true_divide and divide alias to div
        self._verify_count('prim::Constant[value="trunc"]', v3_module, 1)  # rounding_mode argument

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::div", current_module, 1)

        def _helper(m, fn, a, b, out):
            # torch.div takes out as a keyword; the historic helper takes it
            # positionally. Either way the reference writes into a clone.
            if fn is torch.div:
                fn_result = self._try_fn(fn, a, b, out=out.clone())
            else:
                fn_result = self._try_fn(fn, a, b, out.clone())
            m_result = self._try_fn(m, a, b, out)

            if isinstance(m_result, Exception):
                self.assertIsInstance(fn_result, Exception)
            else:
                self.assertEqual(m_result, fn_result)
                # The module must have written its result into out
                self.assertEqual(m_result, out)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
                _helper(v3_module, historic_div_out, a, b, out)
                _helper(current_module, torch.div, a, b, out)

    def test_versioned_div_scalar(self):
        """Tensor / scalar division: version 3 fixtures vs. current behavior."""
        def historic_div_scalar_float(self, other: float):
            return torch.true_divide(self, other)

        def historic_div_scalar_int(self, other: int):
            if self.is_floating_point():
                return torch.true_divide(self, other)
            return torch.divide(self, other, rounding_mode='trunc')

        class MyModuleFloat(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b: float):
                return a / b

        class MyModuleInt(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b: int):
                return a / b

        try:
            v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt")
            v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_int_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        for m in (v3_module_float, v3_module_int):
            self._verify_count("aten::div", m, 2)  # true_divide and divide alias to div
            self._verify_count('prim::Constant[value="trunc"]', m, 1)  # rounding_mode argument

        current_module_float = self._save_load_module(MyModuleFloat)
        current_module_int = self._save_load_module(MyModuleInt)

        for m in (current_module_float, current_module_int):
            self._verify_count("aten::div", m, 1)

        def _helper(m, fn, a, b):
            m_result = self._try_fn(m, a, b)
            fn_result = self._try_fn(fn, a, b)

            if isinstance(m_result, Exception):
                self.assertIsInstance(fn_result, Exception)
            else:
                self.assertEqual(m_result, fn_result)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = val_b

            # The scalar's Python type selects which pair of modules applies
            if isinstance(b, float):
                _helper(v3_module_float, historic_div_scalar_float, a, b)
                _helper(current_module_float, torch.div, a, b)
            else:
                _helper(v3_module_int, historic_div_scalar_int, a, b)
                _helper(current_module_int, torch.div, a, b)

    def test_versioned_div_scalar_reciprocal(self):
        """scalar / Tensor division: version 3 fixtures vs. current behavior."""
        def historic_div_scalar_float_reciprocal(self, other: float):
            return other / self

        def historic_div_scalar_int_reciprocal(self, other: int):
            if self.is_floating_point():
                return other / self
            return torch.divide(other, self, rounding_mode='trunc')

        class MyModuleFloat(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b: float):
                return b / a

        class MyModuleInt(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b: int):
                return b / a

        try:
            v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_float_v3.pt")
            v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_reciprocal_int_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        # NOTE: number / tensor is rewritten to torch.reciprocal(a) * b
        #   so true_divide and floor_divide do not appear in their graphs
        for m in (v3_module_float, v3_module_int):
            self._verify_no("aten::div", m)
            self._verify_no("aten::true_divide", m)
            self._verify_no("aten::floor_divide", m)
            self._verify_count("aten::reciprocal", m, 1)

        current_module_float = self._save_load_module(MyModuleFloat)
        current_module_int = self._save_load_module(MyModuleInt)

        def _helper(m, fn, a, b):
            m_result = self._try_fn(m, a, b)
            # Reverses argument order for torch.div
            if fn is torch.div:
                fn_result = self._try_fn(torch.div, b, a)
            else:
                fn_result = self._try_fn(fn, a, b)

            if isinstance(m_result, Exception):
                self.assertIsInstance(fn_result, Exception)
            elif fn is torch.div or a.is_floating_point():
                self.assertEqual(m_result, fn_result)
            # Otherwise skip: when fn is not torch.div and a is integral,
            # historic_div_scalar_int performs floored division

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = val_b

            if isinstance(b, float):
                _helper(v3_module_float, historic_div_scalar_float_reciprocal, a, b)
                _helper(current_module_float, torch.div, a, b)
            else:
                _helper(v3_module_int, historic_div_scalar_int_reciprocal, a, b)
                _helper(current_module_int, torch.div, a, b)

    def test_versioned_div_scalar_inplace(self):
        """In-place Tensor /= scalar: version 3 fixtures vs. current behavior."""
        def historic_div_scalar_float_inplace(self, other: float):
            return self.true_divide_(other)

        def historic_div_scalar_int_inplace(self, other: int):
            if self.is_floating_point():
                return self.true_divide_(other)

            return self.divide_(other, rounding_mode='trunc')

        class MyModuleFloat(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b: float):
                a /= b
                return a

        class MyModuleInt(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b: int):
                a /= b
                return a

        try:
            v3_module_float = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_float_v3.pt")
            v3_module_int = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_inplace_int_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        for m in (v3_module_float, v3_module_int):
            self._verify_count("aten::div_", m, 2)  # true_divide and divide alias to div
            self._verify_count('prim::Constant[value="trunc"]', m, 1)  # rounding_mode argument

        current_module_float = self._save_load_module(MyModuleFloat)
        current_module_int = self._save_load_module(MyModuleInt)

        for m in (current_module_float, current_module_int):
            self._verify_count("aten::div", m, 1)

        def _helper(m, fn, a, b):
            # NOTE(review): both calls operate in place on the same tensor,
            # unlike test_versioned_div_tensor_inplace, which gives the
            # reference a clone. Presumably both results alias `a`, which
            # would make the equality check below weak -- confirm.
            m_result = self._try_fn(m, a, b)
            fn_result = self._try_fn(fn, a, b)

            if isinstance(m_result, Exception):
                self.assertIsInstance(fn_result, Exception)
            else:
                self.assertEqual(m_result, fn_result)

        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = val_b

            if isinstance(b, float):
                _helper(v3_module_float, historic_div_scalar_float_inplace, a, b)
                _helper(current_module_float, torch.Tensor.div_, a, b)
            else:
                _helper(v3_module_int, historic_div_scalar_int_inplace, a, b)
                _helper(current_module_int, torch.Tensor.div_, a, b)

    # NOTE: Scalar division was already true division in op version 3,
    #   so this test verifies the behavior is unchanged.
    def test_versioned_div_scalar_scalar(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a: float, b: int, c: float, d: int):
                result_0 = a / b
                result_1 = a / c
                result_2 = b / c
                result_3 = b / d
                return (result_0, result_1, result_2, result_3)

        try:
            v3_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_scalar_v3.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::div", v3_module, 4)

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::div", current_module, 4)

        # Both modules must produce the same four quotients
        vals = (5., 3, 2., 7)
        v3_results = v3_module(*vals)
        current_results = current_module(*vals)
        for v3_result, current_result in zip(v3_results, current_results):
            self.assertEqual(v3_result, current_result)

    # NOTE: the JIT was incapable of handling boolean fill values when
    #   PyTorch produced file format versions 0-4
    def test_versioned_full_integer_value(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, int_fill: int):
                size = torch.Size(2, 2)
                a = torch.full(size, int_fill)
                b = torch.full(size, 1)
                return (a, b)

        try:
            v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::full", v4_module, 2)

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::full", current_module, 2)

        # Verifies historic integer type inference is float
        # NOTE: only verifies floating point, not exact dtype, due to
        #   https://github.com/pytorch/pytorch/issues/40470
        results = v4_module(2)
        for result in results:
            self.assertTrue(result.is_floating_point())

        # Verifies values are correct
        a, b = results
        self.assertTrue((a == 2.).all())
        self.assertTrue((b == 1.).all())

    # Tests that torch.full behavior which is the same from prior versions
    #   to version 5 is preserved.
    # NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
    #   it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
    def test_versioned_full_preserved(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, float_fill: float):
                size = (2, 2)
                a = torch.full(size, 1.)
                b = torch.full(size, float_fill)
                c = torch.full(size, float_fill, dtype=torch.long)

                out = torch.empty(size, dtype=torch.long)
                d = torch.full(size, float_fill, out=out)

                e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
                               layout=torch.strided, device='cpu')
                return (a, b, c, d, e)

        try:
            v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
        except Exception:
            self.skipTest("Failed to load fixture!")

        self._verify_count("aten::full", v4_module, 5)

        current_module = self._save_load_module(MyModule)
        self._verify_count("aten::full", current_module, 5)

        self.assertEqual(v4_module(2.), current_module(2.))

    def test_versioned_symbols_reserialization(self):
        """
        Tests that loading and saving serialized Torchscript with a versioned
        symbol won't persist the original function and will inline the
        versioned builtin.
        """
        module_v2 = torch.jit.load(pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt")
        buffer = io.BytesIO()
        torch.jit.save(module_v2, buffer)
        buffer.seek(0)
        module_reserialized = torch.jit.load(buffer)

        # No node kind in the reserialized graph may mention subcmul
        subcmul_nodes = sum("subcmul" in n.kind() for
                            n in module_reserialized.graph.nodes())
        self.assertEqual(subcmul_nodes, 0)

    def test_different_modules(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        # Same class name, different structure
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                return x

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(torch.jit.script(Foo()), second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        # Scripting, saving and reloading the combined module must not raise
        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_functions(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """
        def lol(x):
            return x

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        # Same function name, different body and return type
        def lol(x):  # noqa: F811
            return "hello"

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(torch.jit.script(Foo()), second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        # Scripting, saving and reloading the combined module must not raise
        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_interfaces(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """
        @torch.jit.interface
        class MyInterface(object):
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface(object):
            def __init__(self):
                pass

            def bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self):
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.bar(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        # Same interface and class names, different method name
        @torch.jit.interface
        class MyInterface(object):
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface(object):  # noqa: F811
            def __init__(self):
                pass

            def not_bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self):
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.not_bar(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(torch.jit.script(Foo()), second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        # Scripting, saving and reloading the combined module must not raise
        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_many_collisions(self):
        """
        Combines the module, function, interface and NamedTuple name
        collisions of the tests above in a single pair of saved modules.
        """
        class MyCoolNamedTuple(NamedTuple):
            a: int

        @torch.jit.interface
        class MyInterface(object):
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface(object):
            def __init__(self):
                pass

            def bar(self, x):
                return x

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                x = lol(x)
                x = self.interface.bar(x)

                return x, MyCoolNamedTuple(a=5)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        # Every name below is reused with a different definition
        @torch.jit.interface
        class MyInterface(object):
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface(object):  # noqa: F811
            def __init__(self):
                pass

            def not_bar(self, x):
                return x

        def lol(x):  # noqa: F811
            return "asdofij"

        class MyCoolNamedTuple(NamedTuple):  # noqa: F811
            a: str

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                self.interface.not_bar(x)
                x = lol(x)
                return x, MyCoolNamedTuple(a="hello")

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x, named_tuple_1 = self.first(x)
                x, named_tuple_2 = self.second(x)
                return len(x + named_tuple_2.a) + named_tuple_1.a

        # Scripting, saving and reloading the combined module must not raise
        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_save_load_with_extra_files(self):
        """
        Extra files passed to save come back through load, for both file
        and buffer targets and for both the method and torch.jit APIs.
        """
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return a

        # specifically test binary data
        value = b"bar\x00\xffbaz"

        expected_extra_files = {}
        expected_extra_files['foo'] = value
        # verify that str to bytes conversion also works
        expected_extra_files['foo2'] = "bar"
        m = MyMod()

        # Save to file.
        with TemporaryFileName() as fname:
            m.save(fname, _extra_files=expected_extra_files)
            # values don't matter
            extra_files = {'foo': '', 'foo2': None}
            torch.jit.load(fname, _extra_files=extra_files)
            self.assertEqual(value, extra_files['foo'])
            # results come back always as bytes
            self.assertEqual(b"bar", extra_files['foo2'])

            # Use torch.jit API
            torch.jit.save(m, fname, _extra_files=expected_extra_files)
            extra_files['foo'] = ''
            torch.jit.load(fname, _extra_files=extra_files)
            self.assertEqual(value, extra_files['foo'])

        # Save to buffer.
        buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
        extra_files = {'foo': ''}
        torch.jit.load(buffer, _extra_files=extra_files)
        self.assertEqual(value, extra_files['foo'])

        # Use torch.jit API
        buffer = io.BytesIO()
        torch.jit.save(m, buffer, _extra_files=expected_extra_files)
        buffer.seek(0)
        extra_files = {'foo': ''}
        torch.jit.load(buffer, _extra_files=extra_files)
        self.assertEqual(value, extra_files['foo'])

        # Non-existent file 'bar'
        # NOTE(review): `buffer` is not rewound after the load above --
        # confirm the RuntimeError really comes from the missing 'bar'
        # entry and not from reading an exhausted buffer.
        with self.assertRaises(RuntimeError):
            extra_files['bar'] = ''
            torch.jit.load(buffer, _extra_files=extra_files)

    def test_save_load_using_pathlib(self):
        """Saving to and loading from a pathlib.Path round-trips the module."""
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return 2 * a

        m = MyMod()

        # Save then load.
        with TemporaryFileName() as fname:
            path = pathlib.Path(fname)
            m.save(path)
            m2 = torch.jit.load(path)

        x = torch.tensor([1., 2., 3., 4.])
        self.assertTrue(torch.equal(m(x), m2(x)))

    def test_save_nonexit_file(self):
        """Saving into a directory that does not exist raises RuntimeError."""
        class Foo(torch.nn.Module):
            def forward(self, x):
                return 2 * x

        script_module = torch.jit.script(Foo())
        with self.assertRaises(RuntimeError):
            script_module.save("NonExist/path/test.pt")

    def test_save_namedtuple_input_only(self):
        """
        Even if a NamedTuple is only used as an input argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self, x: FooTuple) -> torch.Tensor:
                return torch.tensor(3)

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded(FooTuple(a=5))
        self.assertEqual(output, torch.tensor(3))

    def test_save_namedtuple_output_only(self):
        """
        Even if a NamedTuple is only used as an output argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self) -> Optional[FooTuple]:
                return None

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded()
        self.assertEqual(output, None)

    def test_save_load_params_buffers_submodules(self):
        """
        Check that parameters, buffers, and submodules are the same after loading.
        """

        class Submodule(torch.nn.Module):
            def __init__(self):
                super().__init__()

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("submodule_a", Submodule())
                self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4)))
                self.register_buffer("buffer", torch.randn(4))
                self.t = torch.rand(4)  # not buffer

                self.parameter_b = torch.nn.Parameter(torch.randn(4))
                self.submodule_b = Submodule()

        m = TestModule()
        m_loaded = self.getExportImportCopy(torch.jit.script(m))

        # Check submodules.
        self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules())))
        for (m_name, _), (loaded_name, _) in zip(m.named_modules(), m_loaded.named_modules()):
            self.assertEqual(m_name, loaded_name)

        # Check parameters.
        self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters())))
        for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()):
            self.assertEqual(m_p, loaded_p)

        # Check buffers.
        self.assertEqual(len(list(m.named_buffers())), len(list(m_loaded.named_buffers())))
        for (m_name, m_buffer), (loaded_name, loaded_buffer) in zip(m.named_buffers(), m_loaded.named_buffers()):
            self.assertEqual(m_name, loaded_name)
            self.assertEqual(m_buffer, loaded_buffer)