mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Change _dynamo.export to be export(f)(*args, **kwargs) (#106109)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/106109 Approved by: https://github.com/voznesenskym
This commit is contained in:
parent
5cbd3fc412
commit
7b9d250f06
16 changed files with 521 additions and 503 deletions
|
|
@ -23,7 +23,7 @@ with torch._dynamo.config.patch(dynamic_shapes=False):
|
|||
torch._dynamo.reset()
|
||||
|
||||
with torch.no_grad():
|
||||
module, _ = torch._dynamo.export(Net().cuda(), x, y)
|
||||
module, _ = torch._dynamo.export(Net().cuda())(x, y)
|
||||
lib_path = torch._inductor.aot_compile(module, [x, y])
|
||||
|
||||
shutil.copy(lib_path, "libaot_inductor_output.so")
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
real = mod(rx)
|
||||
|
||||
# Run it in export
|
||||
graph, _ = torch._dynamo.export(mod, rx)
|
||||
graph, _ = torch._dynamo.export(mod)(rx)
|
||||
|
||||
# Run exported graph with AOT
|
||||
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
|
||||
|
|
@ -185,7 +185,7 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
|
|||
real = mod(x, y)
|
||||
|
||||
# Run it in export
|
||||
graph, _ = torch._dynamo.export(mod, x, y)
|
||||
graph, _ = torch._dynamo.export(mod)(x, y)
|
||||
|
||||
# Assert equal
|
||||
self.assertTrue(torch._dynamo.testing.same(real, graph(x, y)))
|
||||
|
|
|
|||
|
|
@ -238,7 +238,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
|||
real_device = real.device
|
||||
real_dtype = real.dtype
|
||||
|
||||
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
|
||||
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
||||
exported = graph(torch.tensor([0.5]))
|
||||
self.assertEqual(exported.device, real_device)
|
||||
self.assertEqual(exported.dtype, real_dtype)
|
||||
|
|
@ -263,7 +263,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
|||
real_device = real.device
|
||||
real_dtype = real.dtype
|
||||
|
||||
graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
|
||||
graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
||||
exported = graph(torch.tensor([0.5]))
|
||||
self.assertEqual(exported.device, real_device)
|
||||
self.assertEqual(exported.dtype, real_dtype)
|
||||
|
|
@ -347,7 +347,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
|||
real_device = real.device
|
||||
real_dtype = real.dtype
|
||||
|
||||
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
|
||||
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
||||
exported = graph(torch.tensor([0.5]))
|
||||
self.assertEqual(exported.device, real_device)
|
||||
self.assertEqual(exported.dtype, real_dtype)
|
||||
|
|
@ -521,7 +521,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
|||
real_device = real.device
|
||||
real_dtype = real.dtype
|
||||
|
||||
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
|
||||
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
||||
exported = graph(torch.tensor([0.5]))
|
||||
self.assertEqual(exported.device, real_device)
|
||||
self.assertEqual(exported.dtype, real_dtype)
|
||||
|
|
@ -547,7 +547,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
|||
real_device = real.device
|
||||
real_dtype = real.dtype
|
||||
|
||||
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
|
||||
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
|
||||
exported = graph(torch.tensor([0.5]))
|
||||
self.assertEqual(exported.device, real_device)
|
||||
self.assertEqual(exported.dtype, real_dtype)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -10,11 +10,11 @@ from torch.testing._internal.common_utils import IS_FBCODE
|
|||
class MutationExportTests(torch._dynamo.test_case.TestCase):
|
||||
def check_failure_on_export(self, mod, *args):
|
||||
with self.assertRaises(AssertionError):
|
||||
torch._dynamo.export(mod, *args)
|
||||
torch._dynamo.export(mod)(*args)
|
||||
|
||||
def check_same_with_export(self, mod, arg):
|
||||
real_result = mod(arg)
|
||||
graph, _ = torch._dynamo.export(mod, arg)
|
||||
graph, _ = torch._dynamo.export(mod)(arg)
|
||||
result = graph(arg)
|
||||
self.assertTrue(torch._dynamo.utils.same(result, real_result))
|
||||
|
||||
|
|
|
|||
|
|
@ -3243,7 +3243,7 @@ def fn():
|
|||
def f(x):
|
||||
return 1 + torch._shape_as_tensor(x)[0]
|
||||
|
||||
gm, _ = torch._dynamo.export(f, torch.ones(6))
|
||||
gm, _ = torch._dynamo.export(f)(torch.ones(6))
|
||||
|
||||
input_one_dim = torch.ones(6)
|
||||
input_two_dims = torch.ones(7, 4)
|
||||
|
|
@ -3583,8 +3583,8 @@ def fn():
|
|||
def f(pred, pred2, x):
|
||||
return cond(pred, true_fn, false_fn, [pred2, x])
|
||||
|
||||
graph, guard = torch._dynamo.export(
|
||||
f, torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
|
||||
graph, guard = torch._dynamo.export(f)(
|
||||
torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
|
||||
)
|
||||
true_true_sin = graph(
|
||||
torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
|
||||
|
|
@ -3622,8 +3622,8 @@ def fn():
|
|||
def f(pred, x):
|
||||
return cond(pred, true_fn, false_fn, [x])
|
||||
|
||||
graph, guard = torch._dynamo.export(
|
||||
f, torch.tensor(False), torch.tensor([0.25, 0.25])
|
||||
graph, guard = torch._dynamo.export(f)(
|
||||
torch.tensor(False), torch.tensor([0.25, 0.25])
|
||||
)
|
||||
true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25]))
|
||||
self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror))
|
||||
|
|
@ -3891,7 +3891,7 @@ def fn():
|
|||
return self.mod[0](x)
|
||||
|
||||
m = Mod()
|
||||
graph, _ = torch._dynamo.export(m, torch.randn(3, 3))
|
||||
graph, _ = torch._dynamo.export(m)(torch.randn(3, 3))
|
||||
|
||||
def test_nn_sequential_invocation(self):
|
||||
with freeze_rng_state():
|
||||
|
|
@ -3913,7 +3913,7 @@ def fn():
|
|||
m = TestModel()
|
||||
x = torch.rand((2, 2))
|
||||
real = m(x)
|
||||
graph, _ = torch._dynamo.export(m, x)
|
||||
graph, _ = torch._dynamo.export(m)(x)
|
||||
dynamo_result = graph(x)
|
||||
self.assertTrue(same(real, dynamo_result))
|
||||
|
||||
|
|
@ -3937,7 +3937,7 @@ def fn():
|
|||
m = TestModel()
|
||||
x = torch.rand((2, 2))
|
||||
real = m(x)
|
||||
graph, _ = torch._dynamo.export(m, x)
|
||||
graph, _ = torch._dynamo.export(m)(x)
|
||||
dynamo_result = graph(x)
|
||||
self.assertTrue(same(real, dynamo_result))
|
||||
|
||||
|
|
@ -4724,7 +4724,7 @@ def fn():
|
|||
b.tag = "b"
|
||||
b.frog = "ribbit"
|
||||
|
||||
exported = torch._dynamo.export(foo, a, b)
|
||||
exported = torch._dynamo.export(foo)(a, b)
|
||||
out_graph = exported[0]
|
||||
|
||||
nodes = list(out_graph.graph.nodes)
|
||||
|
|
@ -4772,7 +4772,7 @@ def fn():
|
|||
state[0].tag = "STATE_0"
|
||||
state[1].tag = "HMMM"
|
||||
|
||||
exported = torch._dynamo.export(pre_attention_state_ops, i, mems, state)
|
||||
exported = torch._dynamo.export(pre_attention_state_ops)(i, mems, state)
|
||||
out_graph = exported[0]
|
||||
|
||||
nodes = list(out_graph.graph.nodes)
|
||||
|
|
|
|||
|
|
@ -1410,7 +1410,7 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
|
|||
mod = ModuleSpecialFwd()
|
||||
rx = torch.randn([3, 10, 10])
|
||||
real = mod(rx)
|
||||
graph, _ = torch._dynamo.export(mod, rx)
|
||||
graph, _ = torch._dynamo.export(mod)(rx)
|
||||
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
|
||||
|
||||
def test_conv_call_forward_directly(self):
|
||||
|
|
|
|||
|
|
@ -1188,7 +1188,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(cnt.frame_count, 2)
|
||||
self.assertEqual(cnt.op_count, 3) # rand, rand
|
||||
try:
|
||||
graph, _ = torch._dynamo.export(fn)
|
||||
graph, _ = torch._dynamo.export(fn)()
|
||||
# See https://github.com/pytorch/pytorch/pull/87490
|
||||
self.fail("unexpected export success")
|
||||
except torch._dynamo.exc.Unsupported:
|
||||
|
|
@ -2650,11 +2650,11 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(cnt.op_count, 6)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
|
||||
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
|
||||
self.assertTrue(same(exported(*args), f(*args)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "First dim need to be 3"):
|
||||
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
|
||||
exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
|
||||
|
||||
def test_not_rewrite_assert_for_other_errors(self):
|
||||
def f(x):
|
||||
|
|
@ -2686,11 +2686,11 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
return x.cos() + b
|
||||
|
||||
args = (torch.Tensor([3, 4, 5]),)
|
||||
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
|
||||
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
|
||||
self.assertTrue(same(exported(*args), f(*args)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "assertion error"):
|
||||
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
|
||||
exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
|
||||
|
||||
def test_rewrite_assert_with_non_string_msg(self):
|
||||
def f(x):
|
||||
|
|
@ -2718,7 +2718,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
return x.cos() + b
|
||||
|
||||
args = (torch.Tensor([3, 4, 5]),)
|
||||
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
|
||||
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
|
||||
self.assertTrue(same(exported(*args), f(*args)))
|
||||
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
|
|
@ -2728,7 +2728,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(cnt.op_count, 3)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
|
||||
exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
|
||||
self.assertTrue(same(exported(*args), f(*args)))
|
||||
|
||||
def test_size_typematch(self):
|
||||
|
|
@ -2885,7 +2885,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
inp = torch.randn(6, 5)
|
||||
|
||||
gm, _ = torch._dynamo.export(f, torch.randn(4, 5), aten_graph=True)
|
||||
gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
|
||||
self.assertEqual(gm(inp).shape, f(inp).shape)
|
||||
|
||||
@torch._dynamo.config.patch("specialize_int", False)
|
||||
|
|
@ -3058,9 +3058,10 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
gm, _ = torch._dynamo.export(
|
||||
f,
|
||||
aten_graph=True,
|
||||
)(
|
||||
torch.zeros(6, 4),
|
||||
torch.tensor(1),
|
||||
aten_graph=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
f(torch.zeros(6, 4), torch.tensor(1)),
|
||||
|
|
@ -3158,8 +3159,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
gm, _ = torch._dynamo.export(
|
||||
f,
|
||||
torch.zeros(6, 4),
|
||||
aten_graph=True,
|
||||
)(
|
||||
torch.zeros(6, 4),
|
||||
)
|
||||
|
||||
self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class TestSourceMatcher(JitTestCase):
|
|||
return x
|
||||
|
||||
inputs = (torch.randn(3, 3),)
|
||||
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
|
||||
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
|
||||
|
|
@ -69,7 +69,7 @@ class TestSourceMatcher(JitTestCase):
|
|||
return self.maxpool(self.relu(z))
|
||||
|
||||
inputs = (torch.randn(1, 3, 256, 256),)
|
||||
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), *inputs, aten_graph=True)
|
||||
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)(*inputs)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d])
|
||||
|
|
@ -111,7 +111,7 @@ class TestSourceMatcher(JitTestCase):
|
|||
return x
|
||||
|
||||
inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
|
||||
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
|
||||
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.conv2d])
|
||||
|
|
@ -135,7 +135,7 @@ class TestSourceMatcher(JitTestCase):
|
|||
return x
|
||||
|
||||
inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
|
||||
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
|
||||
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu])
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class TestFxPasses(common_utils.TestCase):
|
|||
x = torch.randn(3)
|
||||
y = torch.randn(3)
|
||||
z = torch.randn(3)
|
||||
gm, _ = torch._dynamo.export(func, x, y, z)
|
||||
gm, _ = torch._dynamo.export(func)(x, y, z)
|
||||
torch._dynamo.reset()
|
||||
|
||||
# Purposely name the nodes in a way that will cause a recursive collision later.
|
||||
|
|
@ -44,7 +44,7 @@ class TestFxPasses(common_utils.TestCase):
|
|||
x = torch.randn(3)
|
||||
y = torch.randn(3)
|
||||
z = torch.randn(3)
|
||||
gm, _ = torch._dynamo.export(func, x, y, z)
|
||||
gm, _ = torch._dynamo.export(func)(x, y, z)
|
||||
torch._dynamo.reset()
|
||||
|
||||
# Run `set_node_name` and verify that the names are correct.
|
||||
|
|
|
|||
|
|
@ -833,7 +833,7 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
|
|||
|
||||
def export(
|
||||
f: Callable[..., Any],
|
||||
*args,
|
||||
*extra_args,
|
||||
aten_graph: bool = False,
|
||||
pre_dispatch: bool = False,
|
||||
decomposition_table: Optional[
|
||||
|
|
@ -843,16 +843,14 @@ def export(
|
|||
constraints: Optional[List[Constraint]] = None,
|
||||
assume_static_by_default: bool = False,
|
||||
fake_mode: fake_tensor.FakeTensorMode = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
|
||||
**extra_kwargs,
|
||||
) -> Callable[..., Tuple[torch.fx.GraphModule, Set[_guards.Guard]]]:
|
||||
"""
|
||||
Export an input function f to a format that can be executed outside of PyTorch using the FX graph.
|
||||
|
||||
Args:
|
||||
f (callable): A PyTorch function to be exported.
|
||||
|
||||
*args: Variable length argument list to be passed to the function f.
|
||||
|
||||
aten_graph (bool): If True, exports a graph with ATen operators.
|
||||
If False, exports a graph with Python operators. Default is False.
|
||||
|
||||
|
|
@ -872,10 +870,8 @@ def export(
|
|||
Useful during symbolic tracing, when user input is already fakefied. Implies free fake tensors
|
||||
are allowed on `make_fx`. `fake_mode` must contain a valid (not None) `shape_env` instance.
|
||||
|
||||
**kwargs: Arbitrary keyword arguments to be passed to the function f.
|
||||
|
||||
Returns:
|
||||
A tuple of (graph, guards)
|
||||
A function that given args and kwargs, returns a tuple of (graph, guards)
|
||||
Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
|
||||
Guards: The guards we accumulated during tracing f above
|
||||
|
||||
|
|
@ -887,299 +883,329 @@ def export(
|
|||
|
||||
Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
|
||||
"""
|
||||
check_if_dynamo_supported()
|
||||
torch._C._log_api_usage_once("torch._dynamo.export")
|
||||
if decomposition_table is not None:
|
||||
assert (
|
||||
aten_graph
|
||||
), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
|
||||
if pre_dispatch:
|
||||
assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
|
||||
f = innermost_fn(f)
|
||||
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
|
||||
original_signature = inspect.signature(call_to_inspect)
|
||||
# Deal with "local variable referenced before assignment"
|
||||
_fake_mode = fake_mode
|
||||
_f = f
|
||||
_assume_static_by_default = assume_static_by_default
|
||||
|
||||
assert (
|
||||
not fake_mode or fake_mode.shape_env is not None
|
||||
), "The specified fake_mode must contain a valid shape_env"
|
||||
graph = None
|
||||
out_guards = None
|
||||
graph_captured_input = None
|
||||
graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
|
||||
fake_mode = fake_mode or _guards.detect_fake_mode(args)
|
||||
_allow_fake_constant: bool = (
|
||||
fake_mode is not None
|
||||
) # Allow fake constants during symbolic tracing
|
||||
|
||||
def produce_matching(source_args, candidate_args):
|
||||
matched_elements_positions = []
|
||||
dict_of_source_args = dict()
|
||||
for i in range(0, len(source_args)):
|
||||
element_id = id(source_args[i])
|
||||
dict_of_source_args[element_id] = i
|
||||
|
||||
for i in range(0, len(candidate_args)):
|
||||
arg = candidate_args[i]
|
||||
# 1-element tensor arg can be unspec int/float
|
||||
if isinstance(arg, torch.Tensor) and torch.numel(arg) == 1:
|
||||
if id(arg) in dict_of_source_args:
|
||||
matched_elements_positions.append(dict_of_source_args[id(arg)])
|
||||
elif id(arg.item()) in dict_of_source_args:
|
||||
matched_elements_positions.append(
|
||||
dict_of_source_args[id(arg.item())]
|
||||
)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Dynamo input/output is not consistent with traced input/output"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
id(arg) in dict_of_source_args
|
||||
), "Dynamo input and output is a strict subset of traced input/output"
|
||||
matched_elements_positions.append(dict_of_source_args[id(arg)])
|
||||
|
||||
return matched_elements_positions
|
||||
|
||||
def guard_export_print(guards: Set[_guards.Guard]):
|
||||
nonlocal out_guards
|
||||
assert out_guards is None, "whole graph export entails exactly one guard export"
|
||||
out_guards = guards
|
||||
|
||||
example_inputs = []
|
||||
|
||||
def dynamo_normalization_capturing_compiler(
|
||||
gm: torch.fx.GraphModule, inner_example_inputs
|
||||
):
|
||||
nonlocal graph
|
||||
assert (
|
||||
graph is None
|
||||
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
|
||||
graph = gm
|
||||
|
||||
nonlocal fake_mode, example_inputs
|
||||
fake_mode = fake_mode or _guards.detect_fake_mode(inner_example_inputs)
|
||||
example_inputs = inner_example_inputs
|
||||
|
||||
def result_capturing_wrapper(*graph_inputs):
|
||||
nonlocal graph_captured_result
|
||||
nonlocal graph_captured_input
|
||||
|
||||
graph_captured_input = graph_inputs
|
||||
assert graph is not None
|
||||
graph_captured_result = graph(*graph_inputs)
|
||||
return graph_captured_result
|
||||
|
||||
return result_capturing_wrapper
|
||||
|
||||
flat_args, in_spec = pytree.tree_flatten((args, kwargs))
|
||||
|
||||
remove_from_cache(f)
|
||||
constraint_violation_error = None
|
||||
if tracing_mode != "symbolic":
|
||||
assume_static_by_default = True
|
||||
with patch(f"{__name__}.most_recent_backend", None), config.patch(
|
||||
specialize_int=True,
|
||||
assume_static_by_default=assume_static_by_default,
|
||||
automatic_dynamic_shapes=False,
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
capture_scalar_outputs=True,
|
||||
), torch._guards.export_fake_mode(fake_mode):
|
||||
opt_f = optimize_assert(
|
||||
dynamo_normalization_capturing_compiler,
|
||||
hooks=Hooks(
|
||||
guard_export_fn=guard_export_print,
|
||||
guard_fail_fn=None,
|
||||
),
|
||||
export=True,
|
||||
export_constraints=constraints,
|
||||
)(f)
|
||||
# TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.
|
||||
try:
|
||||
result_traced = opt_f(*args, **kwargs)
|
||||
except ConstraintViolationError as e:
|
||||
constraint_violation_error = e
|
||||
remove_from_cache(f)
|
||||
|
||||
if (
|
||||
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
|
||||
and (dim_constraints := shape_env.dim_constraints) is not None
|
||||
and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
|
||||
):
|
||||
dim_constraints.solve()
|
||||
msg = dim_constraints.prettify_results(original_signature)
|
||||
forced_specializations = dim_constraints.forced_specializations()
|
||||
if forced_specializations:
|
||||
msg = (
|
||||
"Some dynamic dimensions need to be specialized because "
|
||||
"the constraints inferred for them are too complex to specify.\n"
|
||||
f"{forced_specializations}\n{msg}"
|
||||
)
|
||||
if constraint_violation_error:
|
||||
constraint_violation_error.args = (
|
||||
constraint_violation_error.args[0] + msg,
|
||||
)
|
||||
else:
|
||||
if forced_specializations:
|
||||
constraint_violation_error = ConstraintViolationError(msg)
|
||||
else:
|
||||
log.info(
|
||||
"Summary of dimension constraints:%s",
|
||||
msg,
|
||||
)
|
||||
|
||||
# Error if we have any constraints on static values
|
||||
for k in shape_env.var_to_range.keys():
|
||||
if isinstance(k, sympy.Integer):
|
||||
constraint_violation_error = ConstraintViolationError(
|
||||
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
||||
"It appears that you're trying to set a constraint on a "
|
||||
f"value which we evaluated to have a static value of {k}. "
|
||||
"Scroll up to see where this constraint was set."
|
||||
)
|
||||
if constraint_violation_error:
|
||||
raise constraint_violation_error
|
||||
|
||||
assert (
|
||||
graph is not None
|
||||
), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph."
|
||||
assert out_guards is not None, "Failed to produce guards during tracing"
|
||||
assert fake_mode is not None
|
||||
|
||||
matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)
|
||||
|
||||
# NB: This is mostly hitting the cache; Dynamo already converted these
|
||||
example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
|
||||
flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
|
||||
|
||||
assert graph_captured_result is not None
|
||||
flat_both = list(graph_captured_result) + flat_args
|
||||
matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
|
||||
|
||||
if aten_graph:
|
||||
# Running graph with interpreter is needed for propagating the stack_trace
|
||||
def graph_with_interpreter(*args):
|
||||
with torch.fx.traceback.preserve_node_meta():
|
||||
return torch.fx.Interpreter(graph).run(*args)
|
||||
|
||||
with enable_python_dispatcher(), fake_mode:
|
||||
try:
|
||||
graph = make_fx(
|
||||
graph_with_interpreter,
|
||||
decomposition_table=decomposition_table,
|
||||
tracing_mode="real",
|
||||
_allow_non_fake_inputs=True,
|
||||
pre_dispatch=pre_dispatch,
|
||||
_allow_fake_constant=_allow_fake_constant,
|
||||
)(*example_fake_inputs)
|
||||
except CondOpArgsMismatchError as e:
|
||||
# Wrap the internal error to the user-facing error
|
||||
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
|
||||
|
||||
new_graph = FlattenInputOutputSignature(
|
||||
graph,
|
||||
flat_args,
|
||||
matched_input_elements_positions,
|
||||
matched_output_elements_positions,
|
||||
example_fake_inputs,
|
||||
fake_mode,
|
||||
).transform()
|
||||
|
||||
# Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check
|
||||
new_graph.meta["input_shape_constraints"] = (
|
||||
[constraint.serializable_spec for constraint in constraints]
|
||||
if constraints
|
||||
else []
|
||||
)
|
||||
|
||||
def signature_to_fullargspec(sig: inspect.Signature):
|
||||
# Get a list of Parameter objects from the Signature object
|
||||
params = list(sig.parameters.values())
|
||||
# Separate positional arguments, keyword-only arguments and varargs/varkw
|
||||
args = [
|
||||
p.name for p in params if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
]
|
||||
kwonlyargs = [
|
||||
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
]
|
||||
varargs = next(
|
||||
(p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL), None
|
||||
)
|
||||
varkw = next(
|
||||
(p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD), None
|
||||
)
|
||||
# Get default values for positional arguments and keyword-only arguments
|
||||
defaults = tuple(
|
||||
p.default
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
and p.default is not inspect.Parameter.empty
|
||||
)
|
||||
kwonlydefaults = {
|
||||
p.name: p.default
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
and p.default is not inspect.Parameter.empty
|
||||
}
|
||||
# Get annotations for parameters and return value
|
||||
annotations = {}
|
||||
if sig.return_annotation:
|
||||
annotations = {"return": sig.return_annotation}
|
||||
for parameter in params:
|
||||
annotations[parameter.name] = parameter.annotation
|
||||
# Return a FullArgSpec object with the extracted attributes
|
||||
return inspect.FullArgSpec(
|
||||
args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
|
||||
)
|
||||
|
||||
# Make dynamo graph to have same input/output spec as user code
|
||||
def argument_names(f: Callable[..., Any], *args, **kwargs) -> List[str]:
|
||||
fullargspec = signature_to_fullargspec(original_signature)
|
||||
|
||||
# 1. Map `args` 1-to-1 to positional arguments in original signature.
|
||||
input_strs = fullargspec.args[: len(args)]
|
||||
|
||||
if len(args) > len(fullargspec.args):
|
||||
# 2. If there are more arguments left in `args`, they map to varargs in original
|
||||
# signature. Assign names as {varargs}_0, {varargs}_1, ...
|
||||
assert fullargspec.varargs is not None, "More arguments than expected"
|
||||
input_strs += [
|
||||
f"{fullargspec.varargs}_{i}"
|
||||
for i in range(0, len(args) - len(input_strs))
|
||||
]
|
||||
elif len(args) < len(fullargspec.args):
|
||||
# 3. If there are fewer arguments in `args` than `fullargspec.args`,
|
||||
# it implies these are arguments either with default values, or provided in
|
||||
# `kwargs`. The former can be safely ignored. Because Dynamo.export does not
|
||||
# export them as part of the function signature. The latter will be handled
|
||||
# in the next step.
|
||||
for unprovided_arg in fullargspec.args[
|
||||
len(args) : -len(fullargspec.defaults or [])
|
||||
]:
|
||||
assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
|
||||
|
||||
# 4. Keyword arguments provided in `kwargs`.
|
||||
input_strs += list(kwargs.keys())
|
||||
|
||||
# 5. Keyword-only arguments with default values if not provided are not exported
|
||||
# as part of the function signature.
|
||||
for kwonly_arg in fullargspec.kwonlyargs:
|
||||
kwonlydefaults = fullargspec.kwonlydefaults or {}
|
||||
def inner(*args, **kwargs):
|
||||
fake_mode = _fake_mode
|
||||
f = _f
|
||||
assume_static_by_default = _assume_static_by_default
|
||||
check_if_dynamo_supported()
|
||||
torch._C._log_api_usage_once("torch._dynamo.export")
|
||||
if decomposition_table is not None:
|
||||
assert (
|
||||
kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
|
||||
), f"Missing keyword only argument {kwonly_arg}"
|
||||
aten_graph
|
||||
), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
|
||||
if pre_dispatch:
|
||||
assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
|
||||
f = innermost_fn(f)
|
||||
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
|
||||
original_signature = inspect.signature(call_to_inspect)
|
||||
assert (
|
||||
not fake_mode or fake_mode.shape_env is not None
|
||||
), "The specified fake_mode must contain a valid shape_env"
|
||||
graph = None
|
||||
out_guards = None
|
||||
graph_captured_input = None
|
||||
graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
|
||||
fake_mode = fake_mode or _guards.detect_fake_mode(args)
|
||||
_allow_fake_constant: bool = (
|
||||
fake_mode is not None
|
||||
) # Allow fake constants during symbolic tracing
|
||||
|
||||
return input_strs
|
||||
def produce_matching(source_args, candidate_args):
|
||||
matched_elements_positions = []
|
||||
dict_of_source_args = dict()
|
||||
for i in range(0, len(source_args)):
|
||||
element_id = id(source_args[i])
|
||||
dict_of_source_args[element_id] = i
|
||||
|
||||
new_graph.graph._codegen = _PyTreeCodeGen(
|
||||
_PyTreeInfo(
|
||||
argument_names(f, *args, **kwargs),
|
||||
in_spec,
|
||||
out_spec_traced,
|
||||
for i in range(0, len(candidate_args)):
|
||||
arg = candidate_args[i]
|
||||
# 1-element tensor arg can be unspec int/float
|
||||
if isinstance(arg, torch.Tensor) and torch.numel(arg) == 1:
|
||||
if id(arg) in dict_of_source_args:
|
||||
matched_elements_positions.append(dict_of_source_args[id(arg)])
|
||||
elif id(arg.item()) in dict_of_source_args:
|
||||
matched_elements_positions.append(
|
||||
dict_of_source_args[id(arg.item())]
|
||||
)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Dynamo input/output is not consistent with traced input/output"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
id(arg) in dict_of_source_args
|
||||
), "Dynamo input and output is a strict subset of traced input/output"
|
||||
matched_elements_positions.append(dict_of_source_args[id(arg)])
|
||||
|
||||
return matched_elements_positions
|
||||
|
||||
def guard_export_print(guards: Set[_guards.Guard]):
|
||||
nonlocal out_guards
|
||||
assert (
|
||||
out_guards is None
|
||||
), "whole graph export entails exactly one guard export"
|
||||
out_guards = guards
|
||||
|
||||
example_inputs = []
|
||||
|
||||
def dynamo_normalization_capturing_compiler(
|
||||
gm: torch.fx.GraphModule, inner_example_inputs
|
||||
):
|
||||
nonlocal graph
|
||||
assert (
|
||||
graph is None
|
||||
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
|
||||
graph = gm
|
||||
|
||||
nonlocal fake_mode, example_inputs
|
||||
fake_mode = fake_mode or _guards.detect_fake_mode(inner_example_inputs)
|
||||
example_inputs = inner_example_inputs
|
||||
|
||||
def result_capturing_wrapper(*graph_inputs):
|
||||
nonlocal graph_captured_result
|
||||
nonlocal graph_captured_input
|
||||
|
||||
graph_captured_input = graph_inputs
|
||||
assert graph is not None
|
||||
graph_captured_result = graph(*graph_inputs)
|
||||
return graph_captured_result
|
||||
|
||||
return result_capturing_wrapper
|
||||
|
||||
flat_args, in_spec = pytree.tree_flatten((args, kwargs))
|
||||
|
||||
remove_from_cache(f)
|
||||
constraint_violation_error = None
|
||||
if tracing_mode != "symbolic":
|
||||
assume_static_by_default = True
|
||||
with patch(f"{__name__}.most_recent_backend", None), config.patch(
|
||||
specialize_int=True,
|
||||
assume_static_by_default=assume_static_by_default,
|
||||
automatic_dynamic_shapes=False,
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
capture_scalar_outputs=True,
|
||||
), torch._guards.export_fake_mode(fake_mode):
|
||||
opt_f = optimize_assert(
|
||||
dynamo_normalization_capturing_compiler,
|
||||
hooks=Hooks(
|
||||
guard_export_fn=guard_export_print,
|
||||
guard_fail_fn=None,
|
||||
),
|
||||
export=True,
|
||||
export_constraints=constraints,
|
||||
)(f)
|
||||
# TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.
|
||||
try:
|
||||
result_traced = opt_f(*args, **kwargs)
|
||||
except ConstraintViolationError as e:
|
||||
constraint_violation_error = e
|
||||
remove_from_cache(f)
|
||||
|
||||
if (
|
||||
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
|
||||
and (dim_constraints := shape_env.dim_constraints) is not None
|
||||
and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
|
||||
):
|
||||
dim_constraints.solve()
|
||||
msg = dim_constraints.prettify_results(original_signature)
|
||||
forced_specializations = dim_constraints.forced_specializations()
|
||||
if forced_specializations:
|
||||
msg = (
|
||||
"Some dynamic dimensions need to be specialized because "
|
||||
"the constraints inferred for them are too complex to specify.\n"
|
||||
f"{forced_specializations}\n{msg}"
|
||||
)
|
||||
if constraint_violation_error:
|
||||
constraint_violation_error.args = (
|
||||
constraint_violation_error.args[0] + msg,
|
||||
)
|
||||
else:
|
||||
if forced_specializations:
|
||||
constraint_violation_error = ConstraintViolationError(msg)
|
||||
else:
|
||||
log.info(
|
||||
"Summary of dimension constraints:%s",
|
||||
msg,
|
||||
)
|
||||
|
||||
# Error if we have any constraints on static values
|
||||
for k in shape_env.var_to_range.keys():
|
||||
if isinstance(k, sympy.Integer):
|
||||
constraint_violation_error = ConstraintViolationError(
|
||||
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
||||
"It appears that you're trying to set a constraint on a "
|
||||
f"value which we evaluated to have a static value of {k}. "
|
||||
"Scroll up to see where this constraint was set."
|
||||
)
|
||||
if constraint_violation_error:
|
||||
raise constraint_violation_error
|
||||
|
||||
assert (
|
||||
graph is not None
|
||||
), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph."
|
||||
assert out_guards is not None, "Failed to produce guards during tracing"
|
||||
assert fake_mode is not None
|
||||
|
||||
matched_input_elements_positions = produce_matching(
|
||||
flat_args, graph_captured_input
|
||||
)
|
||||
)
|
||||
|
||||
new_graph.recompile()
|
||||
return (new_graph, out_guards)
|
||||
# NB: This is mostly hitting the cache; Dynamo already converted these
|
||||
example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
|
||||
flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
|
||||
|
||||
assert graph_captured_result is not None
|
||||
flat_both = list(graph_captured_result) + flat_args
|
||||
matched_output_elements_positions = produce_matching(
|
||||
flat_both, flat_results_traced
|
||||
)
|
||||
|
||||
if aten_graph:
|
||||
# Running graph with interpreter is needed for propagating the stack_trace
|
||||
def graph_with_interpreter(*args):
|
||||
with torch.fx.traceback.preserve_node_meta():
|
||||
return torch.fx.Interpreter(graph).run(*args)
|
||||
|
||||
with enable_python_dispatcher(), fake_mode:
|
||||
try:
|
||||
graph = make_fx(
|
||||
graph_with_interpreter,
|
||||
decomposition_table=decomposition_table,
|
||||
tracing_mode="real",
|
||||
_allow_non_fake_inputs=True,
|
||||
pre_dispatch=pre_dispatch,
|
||||
_allow_fake_constant=_allow_fake_constant,
|
||||
)(*example_fake_inputs)
|
||||
except CondOpArgsMismatchError as e:
|
||||
# Wrap the internal error to the user-facing error
|
||||
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
|
||||
|
||||
new_graph = FlattenInputOutputSignature(
|
||||
graph,
|
||||
flat_args,
|
||||
matched_input_elements_positions,
|
||||
matched_output_elements_positions,
|
||||
example_fake_inputs,
|
||||
fake_mode,
|
||||
).transform()
|
||||
|
||||
# Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check
|
||||
new_graph.meta["input_shape_constraints"] = (
|
||||
[constraint.serializable_spec for constraint in constraints]
|
||||
if constraints
|
||||
else []
|
||||
)
|
||||
|
||||
def signature_to_fullargspec(sig: inspect.Signature):
|
||||
# Get a list of Parameter objects from the Signature object
|
||||
params = list(sig.parameters.values())
|
||||
# Separate positional arguments, keyword-only arguments and varargs/varkw
|
||||
args = [
|
||||
p.name
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
]
|
||||
kwonlyargs = [
|
||||
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
]
|
||||
varargs = next(
|
||||
(p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
|
||||
None,
|
||||
)
|
||||
varkw = next(
|
||||
(p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
|
||||
None,
|
||||
)
|
||||
# Get default values for positional arguments and keyword-only arguments
|
||||
defaults = tuple(
|
||||
p.default
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
and p.default is not inspect.Parameter.empty
|
||||
)
|
||||
kwonlydefaults = {
|
||||
p.name: p.default
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
and p.default is not inspect.Parameter.empty
|
||||
}
|
||||
# Get annotations for parameters and return value
|
||||
annotations = {}
|
||||
if sig.return_annotation:
|
||||
annotations = {"return": sig.return_annotation}
|
||||
for parameter in params:
|
||||
annotations[parameter.name] = parameter.annotation
|
||||
# Return a FullArgSpec object with the extracted attributes
|
||||
return inspect.FullArgSpec(
|
||||
args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
|
||||
)
|
||||
|
||||
# Make dynamo graph to have same input/output spec as user code
|
||||
def argument_names(f: Callable[..., Any], *args, **kwargs) -> List[str]:
|
||||
fullargspec = signature_to_fullargspec(original_signature)
|
||||
|
||||
# 1. Map `args` 1-to-1 to positional arguments in original signature.
|
||||
input_strs = fullargspec.args[: len(args)]
|
||||
|
||||
if len(args) > len(fullargspec.args):
|
||||
# 2. If there are more arguments left in `args`, they map to varargs in original
|
||||
# signature. Assign names as {varargs}_0, {varargs}_1, ...
|
||||
assert fullargspec.varargs is not None, "More arguments than expected"
|
||||
input_strs += [
|
||||
f"{fullargspec.varargs}_{i}"
|
||||
for i in range(0, len(args) - len(input_strs))
|
||||
]
|
||||
elif len(args) < len(fullargspec.args):
|
||||
# 3. If there are fewer arguments in `args` than `fullargspec.args`,
|
||||
# it implies these are arguments either with default values, or provided in
|
||||
# `kwargs`. The former can be safely ignored. Because Dynamo.export does not
|
||||
# export them as part of the function signature. The latter will be handled
|
||||
# in the next step.
|
||||
for unprovided_arg in fullargspec.args[
|
||||
len(args) : -len(fullargspec.defaults or [])
|
||||
]:
|
||||
assert (
|
||||
unprovided_arg in kwargs
|
||||
), f"Missing argument {unprovided_arg}"
|
||||
|
||||
# 4. Keyword arguments provided in `kwargs`.
|
||||
input_strs += list(kwargs.keys())
|
||||
|
||||
# 5. Keyword-only arguments with default values if not provided are not exported
|
||||
# as part of the function signature.
|
||||
for kwonly_arg in fullargspec.kwonlyargs:
|
||||
kwonlydefaults = fullargspec.kwonlydefaults or {}
|
||||
assert (
|
||||
kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
|
||||
), f"Missing keyword only argument {kwonly_arg}"
|
||||
|
||||
return input_strs
|
||||
|
||||
new_graph.graph._codegen = _PyTreeCodeGen(
|
||||
_PyTreeInfo(
|
||||
argument_names(f, *args, **kwargs),
|
||||
in_spec,
|
||||
out_spec_traced,
|
||||
)
|
||||
)
|
||||
|
||||
new_graph.recompile()
|
||||
return (new_graph, out_guards)
|
||||
|
||||
if extra_args or extra_kwargs:
|
||||
warnings.warn(
|
||||
"export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. "
|
||||
"If you don't migrate, we may break your export call in the future if your user defined kwargs "
|
||||
"conflict with future kwargs added to export(f)."
|
||||
)
|
||||
return inner(*extra_args, **extra_kwargs)
|
||||
else:
|
||||
return inner
|
||||
|
||||
|
||||
def optimize_assert(
|
||||
|
|
|
|||
|
|
@ -62,7 +62,6 @@ from .passes.add_runtime_assertions_for_constraints_pass import _AddRuntimeAsser
|
|||
#
|
||||
# result = torch._dynamo.export(
|
||||
# my_model,
|
||||
# *sixtyfour_tensors,
|
||||
# constraints=[
|
||||
# # if you do only dynamic_dim, this is sugar for
|
||||
# # -Inf <= dynamic_dim(blah, 0) <= Inf; we don’t otherwise
|
||||
|
|
@ -74,6 +73,8 @@ from .passes.add_runtime_assertions_for_constraints_pass import _AddRuntimeAsser
|
|||
# # NB: But we actually truncate ranges to be >= 2, because of
|
||||
# # 0/1 specialization
|
||||
# ]
|
||||
# )(
|
||||
# *sixtyfour_tensors,
|
||||
# )
|
||||
def dynamic_dim(t: torch.Tensor, index: int):
|
||||
if not isinstance(t, torch.Tensor):
|
||||
|
|
@ -152,10 +153,11 @@ def export(
|
|||
try:
|
||||
gm_torch_level, _ = torch._dynamo.export(
|
||||
f,
|
||||
*args,
|
||||
constraints=constraints,
|
||||
assume_static_by_default=True,
|
||||
tracing_mode="symbolic",
|
||||
)(
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -141,9 +141,10 @@ def get_aten_graph_module(
|
|||
import torch._dynamo
|
||||
aten_pattern, _ = torch._dynamo.export(
|
||||
pattern,
|
||||
*copy.deepcopy(example_inputs),
|
||||
aten_graph=True,
|
||||
tracing_mode="real",
|
||||
)(
|
||||
*copy.deepcopy(example_inputs),
|
||||
**kwargs,
|
||||
)
|
||||
aten_pattern.graph.eliminate_dead_code()
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ def _mark_nodes_as_annotated(nodes: List[Node]):
|
|||
|
||||
|
||||
def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
|
||||
gm, _ = torchdynamo.export(function, *inputs, aten_graph=True)
|
||||
gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
|
||||
gm.graph.eliminate_dead_code()
|
||||
return gm.graph
|
||||
|
||||
|
|
|
|||
|
|
@ -191,9 +191,10 @@ class DynamoExport(exporter.FXGraphExtractor):
|
|||
fx_mode = "symbolic" if options.dynamic_shapes else "fake"
|
||||
graph_module, graph_guard = torch._dynamo.export(
|
||||
wrapped_model,
|
||||
*model_args,
|
||||
tracing_mode=fx_mode,
|
||||
fake_mode=fake_mode, # type: ignore[arg-type]
|
||||
)(
|
||||
*model_args,
|
||||
**model_kwargs,
|
||||
)
|
||||
del graph_guard # Unused
|
||||
|
|
|
|||
|
|
@ -798,7 +798,7 @@ class Modularize(_pass.Transform):
|
|||
>>> out = self.linear(out)
|
||||
>>> return out
|
||||
>>>
|
||||
>>> gm, _ = torch._dynamo.export(TestModule(), torch.tensor([0, 1, 2]), aten_graph=True)
|
||||
>>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(torch.tensor([0, 1, 2]))
|
||||
>>> gm.print_readable()
|
||||
|
||||
>>> gm = passes.Modularize(infra.DiagnosticContext("test_context", "1.0"), gm).run()
|
||||
|
|
|
|||
Loading…
Reference in a new issue