mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Things that needed to be fixed: 1. Fix a bug with returning dict output type 2. Make pass_base work with map implementation 3. Fix subtle bug with dynamo not propagating "val" in node.meta 4. Add export_constraints field in ExportCase in ExportDB Pull Request resolved: https://github.com/pytorch/pytorch/pull/103861 Approved by: https://github.com/zhxchen17, https://github.com/ydwu4
274 lines
9.5 KiB
Python
274 lines
9.5 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import unittest
|
|
|
|
import torch
|
|
import torch._dynamo as torchdynamo
|
|
from torch._export import export, dynamic_dim
|
|
from torch._export.trace import do_not_use_experimental_export
|
|
from torch._export.constraints import constrain_as_size, constrain_as_value
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
from functorch.experimental.control_flow import map
|
|
|
|
|
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
|
class TestExperimentalExport(TestCase):
|
|
@unittest.skip("TypeError: <lambda>() missing 1 required positional argument")
|
|
def test_export_simple_model_with_attr(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self, float_val):
|
|
super().__init__()
|
|
self.float_val = float_val
|
|
|
|
def forward(self, x):
|
|
y = x + self.float_val
|
|
return y.cos()
|
|
|
|
inp = (torch.ones(6, 4, requires_grad=True),)
|
|
mod = Foo(0.5)
|
|
|
|
exported_program = do_not_use_experimental_export(mod, inp)
|
|
self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))
|
|
|
|
def test_export_simple_model(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self, float_val):
|
|
super().__init__()
|
|
self.float_val = float_val
|
|
|
|
def forward(self, x):
|
|
return x.cos()
|
|
|
|
inp = (torch.ones(6, 4, requires_grad=True),)
|
|
mod = Foo(0.5)
|
|
|
|
exported_program = do_not_use_experimental_export(mod, inp)
|
|
self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))
|
|
|
|
@unittest.skip("TypeError: <lambda>() missing 1 required positional argument")
|
|
def test_export_simple_model_buffer_mutation(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self, float_val):
|
|
super().__init__()
|
|
self.register_buffer("buffer1", torch.ones(6, 1))
|
|
|
|
def forward(self, x):
|
|
self.buffer1.add_(2)
|
|
return x.cos() + self.buffer1.sin()
|
|
|
|
inp = (torch.ones(6, 4, requires_grad=True),)
|
|
mod = Foo(0.5)
|
|
|
|
exported_program = do_not_use_experimental_export(mod, inp)
|
|
mutated_buffer, output = exported_program.fw_module(*inp)
|
|
# TODO (tmanlaibaatar) enable this once we figure out
|
|
# how to do buffer mutation
|
|
# self.assertEqual(mutated_buffer.sum().item(), 30)
|
|
self.assertEqual(output, mod(*inp))
|
|
|
|
|
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
|
class TestDynamismExpression(TestCase):
|
|
def test_export_inline_constraints(self):
|
|
|
|
def f(x):
|
|
b = x.item()
|
|
constrain_as_size(b, min=2, max=5)
|
|
return torch.full((b, 1), 1)
|
|
|
|
inp = (torch.tensor([3]),)
|
|
ref = f(*inp)
|
|
|
|
gm = export(f, inp)
|
|
res = gm(*inp)
|
|
|
|
self.assertTrue(torchdynamo.utils.same(ref, res))
|
|
|
|
gm = make_fx(f, tracing_mode="symbolic")(*inp)
|
|
res = gm(*inp)
|
|
self.assertTrue(torchdynamo.utils.same(ref, res))
|
|
|
|
def test_export_constraints_error(self):
|
|
def invalid_size(x):
|
|
b = x.item()
|
|
constrain_as_size(b, min=0, max=5)
|
|
return torch.full((b, 1), 1)
|
|
|
|
inp = (torch.tensor([3]),)
|
|
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Unable to set min size"):
|
|
export(invalid_size, inp)
|
|
|
|
def invalid_input_conflict_with_inline_constraints(x):
|
|
b = x.item()
|
|
constrain_as_size(b, min=2, max=5)
|
|
return torch.full((b, 1), 1)
|
|
|
|
inp = (torch.tensor([6]),)
|
|
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid value 6 for range"):
|
|
export(invalid_input_conflict_with_inline_constraints, inp)
|
|
|
|
def invalid_input_conflict_with_input_constraints(x):
|
|
return x + 1
|
|
|
|
inp = torch.zeros([3])
|
|
inp_constraints = [
|
|
dynamic_dim(inp, 0) > 5,
|
|
]
|
|
with self.assertRaisesRegex(torchdynamo.exc.UserError, "not in range"):
|
|
export(
|
|
invalid_input_conflict_with_input_constraints,
|
|
(inp,),
|
|
constraints=inp_constraints,
|
|
)
|
|
|
|
|
|
def conflicting_constraints(x):
|
|
b = x.item()
|
|
constrain_as_size(b, min=2, max=3)
|
|
constrain_as_size(b, min=4, max=5)
|
|
return torch.full((b, 1), 1)
|
|
|
|
inp = (torch.tensor([3]),)
|
|
|
|
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid ranges"):
|
|
export(conflicting_constraints, inp)
|
|
|
|
def test_export_assume_static_by_default(self):
|
|
def branch_on_shape(x: torch.Tensor):
|
|
if x.shape[0] == 4:
|
|
return x + 1
|
|
else:
|
|
return x
|
|
|
|
inp = (torch.rand(4, 5),)
|
|
|
|
# Being able to export means shape is preserved as static
|
|
export(branch_on_shape, inp)
|
|
|
|
|
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
|
class TestExport(TestCase):
|
|
def test_basic(self):
|
|
def f(x, y):
|
|
return x[0] + y
|
|
|
|
inp = ([torch.ones(1, 3)], torch.ones(1, 3))
|
|
exported_program = export(f, inp)
|
|
self.assertTrue(torch.allclose(exported_program(*inp), f(*inp)))
|
|
|
|
def test_raise_user_error_when_guard_on_data_dependent_operation(self):
|
|
def fn_ddo(x):
|
|
y = x.nonzero()
|
|
z = y.shape[0]
|
|
if z > 2:
|
|
return x.cos()
|
|
else:
|
|
return x.sin()
|
|
|
|
with self.assertRaisesRegex(
|
|
torchdynamo.exc.UserError,
|
|
"trying to get a value out of symbolic int"
|
|
):
|
|
_ = export(fn_ddo, (torch.tensor([2, 3, 5]),), constraints=None)
|
|
|
|
def test_if_functional(self):
|
|
def foo(x):
|
|
z = x + 4
|
|
z.add_(4)
|
|
y = z.view(x.shape)
|
|
return x.cos() + y.cos()
|
|
|
|
gm = export(foo, (torch.tensor([2, 3, 5]),), constraints=None)
|
|
|
|
view_count = 0
|
|
for node in gm.graph.nodes:
|
|
if node.op == "call_function" and node.target == torch.ops.aten.add_.Tensor:
|
|
# No more inplace mutation
|
|
self.assertNotEqual(
|
|
node.target,
|
|
torch.ops.aten.add_.Tensor,
|
|
"There shouldn't be any inplace mutation node in the graph."
|
|
)
|
|
if node.op == "call_function" and node.target == torch.ops.aten.view.default:
|
|
view_count += 1
|
|
|
|
# There should be nonzero view nodes in the graph
|
|
self.assertTrue(view_count > 0)
|
|
|
|
def test_export_mod_constraints(self):
|
|
class BasicDynamiShapeModel(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return x.view(x.shape[0] - 1, -1)
|
|
|
|
m = BasicDynamiShapeModel()
|
|
a = torch.randn(3, 4)
|
|
constraints = [3 <= dynamic_dim(a, 0), dynamic_dim(a, 1)]
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UserError,
|
|
(
|
|
"Some dynamic dimensions need to be specialized because "
|
|
"the constraints inferred for them are too complex to specify"
|
|
".*\n.*\\[0\\], which was marked dynamic, must be specialized to 3"
|
|
".*\n.*\\[1\\], which was marked dynamic, must be specialized to 4"
|
|
),
|
|
):
|
|
torch._export.export(m, (a,), constraints=constraints)
|
|
em = torch._export.export(m, (a,))
|
|
x = torch.randn(3, 5)
|
|
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
|
|
em(x)
|
|
|
|
def test_export_constrain_static(self):
|
|
def f(x, y):
|
|
b = x.item()
|
|
constrain_as_size(b, min=2, max=5)
|
|
c = y.dim()
|
|
constrain_as_value(c, min=1, max=3)
|
|
z = y[0:c]
|
|
return torch.empty((b, y.shape[0])), z
|
|
|
|
x = torch.tensor([3])
|
|
y = torch.randn([8, 8, 6])
|
|
example_inputs = (x, y)
|
|
constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10]
|
|
with self.assertRaisesRegex(
|
|
torchdynamo.exc.UserError, "It appears that you're trying to set a constraint " +
|
|
"on a value which we evaluated to have a static value of 3. "
|
|
):
|
|
export(f, example_inputs, constraints)
|
|
|
|
def test_not_correct_dim(self):
|
|
def f(x):
|
|
return x.cos()
|
|
|
|
def g(x):
|
|
return x + 4
|
|
|
|
inp_for_f = torch.tensor(5)
|
|
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Cannot mark 0-dimension tensors to be dynamic"):
|
|
constraints = [dynamic_dim(inp_for_f, 0)]
|
|
|
|
inp_for_f_mul_dim = torch.ones(5, 5)
|
|
with self.assertRaisesRegex(
|
|
torchdynamo.exc.UserError,
|
|
"Expected the dimension passed to dynamic_dim to be in the range \\[0:1\\]"
|
|
):
|
|
constraints = [dynamic_dim(inp_for_f_mul_dim, 2)]
|
|
|
|
inp_for_g = 4
|
|
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Expected tensor as input to dynamic_dim"):
|
|
constraints = [dynamic_dim(inp_for_g, 0)]
|
|
|
|
def test_map(self):
|
|
def list_tensor_map(xs, y, z):
|
|
def body(x, y, z):
|
|
return x + y + z
|
|
|
|
return map(body, xs, y, z)
|
|
|
|
inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
|
|
exported_program = export(list_tensor_map, inps)
|
|
self.assertTrue(torch.allclose(exported_program(*inps), list_tensor_map(*inps)))
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|