mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74353 Repatched `d00de0d43598522b8f6ab2de553b6aaf6768faa5` by Nora Belrose (norabelrose), with the following changes: * Register fake source for the generated methods in linecache so that `inspect.getsource` will succeed. * This patching is only triggered if the given dataclass was previously passed to `torch.jit.script`; effectively, we make this feature opt-in. ## Original Summary: Fixes https://github.com/pytorch/pytorch/issues/72901. Since we can't get access to the source code for synthesized magic methods on dataclasses, we have to synthesize our own versions; torch/jit/_dataclass_impls.py has the code that does this. What's supported: synthesized __init__, __eq__, and the comparison magic methods when order=True is set on the dataclass decorator; default values for fields; __post_init__, including using InitVar fields inside of __post_init__, on Python 3.8+; overriding __eq__ or any of the comparison magic methods to provide your own implementation. What's not supported: default factory initializers for fields; frozen dataclasses; InitVar on Python 3.7; __repr__ and __hash__ (these are actually implemented, but the TorchScript interpreter won't call them); using the != operator on dataclasses inside TorchScript. This is because TorchScript requires that you implement __ne__ to use this operator, whereas in regular Python the != operator resolves to the negation of whatever __eq__ returns if there's no __ne__. Dataclasses don't actually synthesize an __ne__ method for this reason. I've been toying with different ways to fix this, but != is not working in this PR at the moment.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74889 Test Plan: unittest. Also ran the previously failing test: ``` buck test mode/dev-nosan //fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests -- --exact 'fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests - test_mixmatch_multiclass (fblearner.flow.projects.fluent2.definition.transformers.contrib.faim.test.faim_mixmatch_test.TestFaimTransformerMixMatch)' ``` which now passes. Reviewed By: zhxchen17 Differential Revision: D35206262 Pulled By: qihqi Pull Request resolved: https://github.com/pytorch/pytorch/pull/76771 Approved by: https://github.com/seemethere
363 lines · 12 KiB · Python
# Owner(s): ["oncall: jit"]
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
|
from torch.testing import FileCheck
|
|
from torch import jit
|
|
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
|
import os
|
|
import sys
|
|
import torch
|
|
import torch.testing._internal.jit_utils
|
|
import torch.nn as nn
|
|
from torch.testing._internal.common_utils import freeze_rng_state
|
|
|
|
# Put the parent of this file's directory (test/, after resolving symlinks)
# onto sys.path so helper modules living under test/ can be imported.
# NOTE(review): the `jit.test_module_interface` import near the top of this
# file executes before this append, so it relies on test/ already being
# importable (e.g. when the tests are driven from test/test_jit.py).
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)

# This module only defines test cases; executing it directly is an error and
# the message points the user at the test/test_jit.py driver instead.
if __name__ == '__main__':
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
|
|
|
|
class TestMisc(JitTestCase):
|
|
def test_joined_str(self):
|
|
def func(x):
|
|
hello, test = "Hello", "test"
|
|
print(f"{hello + ' ' + test}, I'm a {test}")
|
|
print("format blank")
|
|
hi = 'hi'
|
|
print(f"stuff before {hi}")
|
|
print(f"{hi} stuff after")
|
|
return x + 1
|
|
|
|
x = torch.arange(4., requires_grad=True)
|
|
# TODO: Add support for f-strings in string parser frontend
|
|
# self.checkScript(func, [x], optimize=True, capture_output=True)
|
|
|
|
with self.capture_stdout() as captured:
|
|
out = func(x)
|
|
|
|
scripted = torch.jit.script(func)
|
|
with self.capture_stdout() as captured_script:
|
|
out_script = func(x)
|
|
|
|
self.assertEqual(out, out_script)
|
|
self.assertEqual(captured, captured_script)
|
|
|
|
def test_kwarg_support(self):
|
|
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
|
|
class M(torch.nn.Module):
|
|
def forward(self, *, n_tokens: int, device_name: str = 2):
|
|
pass
|
|
torch.jit.script(M())
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, *, n_tokens: int, device_name: str):
|
|
return n_tokens, device_name
|
|
|
|
sm = torch.jit.script(M())
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"):
|
|
sm()
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "positional arg"):
|
|
sm(3, 'hello')
|
|
|
|
self.assertEqual(sm(n_tokens=3, device_name='hello'), (3, 'hello'))
|
|
|
|
def test_tuple_subscripted_assign(self):
|
|
with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
|
|
@torch.jit.script
|
|
def foo(a: Tuple[int, int]) -> None:
|
|
a[0] = a[1]
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
|
|
@torch.jit.script
|
|
def bar(a: Tuple[int, int]) -> None:
|
|
a[0] += a[1]
|
|
|
|
def test_subexpression_List_Future(self):
|
|
|
|
@torch.jit.script
|
|
def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
|
|
return x[0]
|
|
|
|
FileCheck().check('Future[int]').check('Future[int]').run(fn.graph)
|
|
|
|
def test_subexpression_Future_annotate(self):
|
|
@torch.jit.script
|
|
def fn() -> torch.jit.Future[int]:
|
|
x: List[torch.jit.Future[int]] = []
|
|
return x[0]
|
|
|
|
FileCheck().check("Future[int][]").run(fn.graph)
|
|
|
|
def test_future_isinstance(self):
|
|
@torch.jit.script
|
|
def fn(x: Any) -> torch.jit.Future[int]:
|
|
assert isinstance(x, jit.Future[int])
|
|
return x
|
|
|
|
FileCheck().check("Future[int]").run(fn.graph)
|
|
|
|
def test_str_refine_any(self):
|
|
def forward(x: Any) -> str:
|
|
if isinstance(x, str):
|
|
return x
|
|
return "foo"
|
|
forward = torch.jit.script(forward)
|
|
self.assertEqual(forward(1), "foo")
|
|
self.assertEqual(forward("bar"), "bar")
|
|
|
|
def test_subexpression_Tuple_int_int_Future(self):
|
|
|
|
@torch.jit.script
|
|
def fn(x: Tuple[int, int, torch.jit.Future[int]]) -> Tuple[int, torch.jit.Future[int]]:
|
|
return x[0], x[2]
|
|
|
|
FileCheck().check('(int, int, Future[int])').check('(int, Future[int])').run(fn.graph)
|
|
|
|
def test_subexpression_Dict_int_Future(self):
|
|
|
|
@torch.jit.script
|
|
def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
|
|
return x[y]
|
|
|
|
FileCheck().check('Dict(int, Future(int))').check('Future[int]').run(fn.graph)
|
|
|
|
def test_subexpression_Optional(self):
|
|
|
|
@torch.jit.script
|
|
def fn(x: Optional[Dict[int, torch.jit.Future[int]]]) -> Optional[torch.jit.Future[int]]:
|
|
if x is not None:
|
|
return x[0]
|
|
else:
|
|
return None
|
|
|
|
FileCheck().check('Dict(int, Future(int))?').run(fn.graph)
|
|
|
|
def test_if_returning_any(self):
|
|
"""
|
|
Check that an if statement can return different
|
|
types early from each branch when the return
|
|
type of the function is Any.
|
|
"""
|
|
def if_function(inp: torch.Tensor) -> Any:
|
|
if inp.shape[0] == 1:
|
|
return inp * inp
|
|
else:
|
|
return "str"
|
|
|
|
self.checkScript(if_function, (torch.randn(5),))
|
|
|
|
def test_hacked_twin(self):
|
|
|
|
def gen_data():
|
|
with freeze_rng_state():
|
|
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
|
|
|
|
input, index, value, = gen_data()
|
|
input1, index1, value1, = gen_data()
|
|
out1 = torch.ops.aten.index_put.hacked_twin(input, [index], value, accumulate=False)
|
|
out2 = torch.index_put(input1, [index1], value1, accumulate=False)
|
|
self.assertEqual(out1, out2)
|
|
|
|
torch.ops.aten.index_put_.hacked_twin(input, [index], value, accumulate=False)
|
|
torch.index_put_(input1, [index1], value1, accumulate=False)
|
|
self.assertEqual(input, input1)
|
|
|
|
def test_export_opnames_interface(self):
|
|
|
|
@torch.jit.interface
|
|
class OneTwoModule(nn.Module):
|
|
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
pass
|
|
|
|
def two(self, x: torch.Tensor) -> torch.Tensor:
|
|
pass
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
pass
|
|
|
|
class FooMod(nn.Module):
|
|
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x + y
|
|
|
|
def two(self, x: torch.Tensor) -> torch.Tensor:
|
|
return 2 * x
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return self.one(self.two(x), x)
|
|
|
|
class BarMod(nn.Module):
|
|
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x * y
|
|
|
|
def two(self, x: torch.Tensor) -> torch.Tensor:
|
|
return 2 / x
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return self.two(self.one(x, x))
|
|
|
|
make_global(OneTwoModule)
|
|
|
|
class M(nn.Module):
|
|
sub : OneTwoModule
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.sub = BarMod()
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return self.sub.forward(x)
|
|
|
|
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
|
|
return mod_list[0].forward(x) + mod_list[1].forward(x)
|
|
|
|
torch._C._enable_mobile_interface_call_export()
|
|
scripted_M_mod = torch.jit.script(M())
|
|
self.assertTrue(set(['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal']).issubset(
|
|
set(torch.jit.export_opnames(scripted_M_mod))))
|
|
|
|
scripted_M_mod.sub = torch.jit.script(FooMod())
|
|
self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset(
|
|
set(torch.jit.export_opnames(scripted_M_mod))))
|
|
|
|
def test_math_inf(self):
|
|
from math import inf
|
|
|
|
def foo():
|
|
return inf
|
|
|
|
self.checkScript(foo, ())
|
|
|
|
def test_list_literal_infer(self):
|
|
def expects_intlist(x: List[int]):
|
|
x.append(3)
|
|
return x
|
|
|
|
def foo():
|
|
return expects_intlist([])
|
|
|
|
self.checkScript(foo, ())
|
|
|
|
def annotated_list_fail():
|
|
return expects_intlist(torch.jit.annotate([], List[Tensor]))
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
torch.jit.script(annotated_list_fail)
|
|
|
|
def non_temporary_fail():
|
|
a = []
|
|
return expects_intlist(a)
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
torch.jit.script(non_temporary_fail)
|
|
|
|
|
|
@torch.jit.script
|
|
def test_return():
|
|
return []
|
|
|
|
FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)
|
|
|
|
def test_legacy_tensor_constructor(self):
|
|
# testing PyObject overload
|
|
def test_all_dtypes():
|
|
return (
|
|
torch.BoolTensor([2]),
|
|
torch.LongTensor([3]),
|
|
torch.ByteTensor([4]),
|
|
torch.CharTensor([5]),
|
|
torch.DoubleTensor([6]),
|
|
torch.FloatTensor([7]),
|
|
torch.IntTensor([8]),
|
|
torch.ShortTensor([1]),
|
|
torch.HalfTensor([1]),
|
|
)
|
|
|
|
self.checkScript(test_all_dtypes, ())
|
|
|
|
# now test empty overload
|
|
def empty_overload():
|
|
return torch.LongTensor(2, 3, 4)
|
|
|
|
eager = empty_overload()
|
|
jit = torch.jit.script(empty_overload)()
|
|
eager[:] = 1
|
|
jit[:] = 1
|
|
self.assertEqual(eager, jit)
|
|
|
|
def no_inputs():
|
|
return torch.DoubleTensor()
|
|
|
|
self.checkScript(no_inputs, ())
|
|
|
|
# bad schema
|
|
def multiple_args():
|
|
return torch.LongTensor(1, [2])
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "multiple positional arguments that were not all integers"):
|
|
torch.jit.script(multiple_args)
|
|
|
|
# kwarg bad schema
|
|
def bad_kwarg():
|
|
return torch.LongTensor(hello="1")
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "hello"):
|
|
torch.jit.script(bad_kwarg)
|
|
|
|
|
|
def test_broadcasting_list(self):
|
|
"""
|
|
Test BroadcastingList and torch.nn._size_N_t alias
|
|
"""
|
|
from torch._jit_internal import BroadcastingList2
|
|
from torch.nn.common_types import _size_2_t
|
|
|
|
def sum_i(x: _size_2_t) -> int:
|
|
return x[0] + x[1]
|
|
|
|
def sum_f(x: BroadcastingList2[float]) -> float:
|
|
return x[0] + x[1]
|
|
|
|
self.assertTrue(torch.jit.script(sum_i)(4) == 8)
|
|
self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.)
|
|
|
|
def test_parse_ir_annotate(self):
|
|
ir = """
|
|
graph():
|
|
%3 : int[] = prim::Constant[value=annotate(List[int], [])]()
|
|
return (%3)
|
|
"""
|
|
graph = torch._C.parse_ir(ir, True)
|
|
func = torch._C._create_function_from_graph("forward", graph)
|
|
ret = func()
|
|
self.assertTrue(ret == [])
|
|
|
|
def test_parse_ir_single_element_tensor_positive(self):
|
|
ir = """
|
|
graph():
|
|
%7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
|
|
return (%7)
|
|
"""
|
|
graph = torch._C.parse_ir(ir, True)
|
|
func = torch._C._create_function_from_graph("forward", graph)
|
|
ret = func()
|
|
self.assertTrue(ret.numel() == 1)
|
|
self.assertTrue(len(ret.size()) == 1)
|
|
|
|
def test_parse_ir_single_element_tensor_negative(self):
|
|
ir = """
|
|
graph():
|
|
%7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
|
|
return (%7)
|
|
"""
|
|
graph = torch._C.parse_ir(ir, True)
|
|
func = torch._C._create_function_from_graph("forward", graph)
|
|
ret = func()
|
|
self.assertTrue(ret.numel() == 1)
|
|
self.assertTrue(len(ret.size()) == 1)
|