Change _dynamo.export to be export(f)(*args, **kwargs) (#106109)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106109
Approved by: https://github.com/voznesenskym
Author: Edward Z. Yang <ezyang@meta.com>, 2023-07-27 14:39:15 -07:00 (committed by PyTorch MergeBot)
Parent: 5cbd3fc412
Commit: 7b9d250f06

16 changed files with 521 additions and 503 deletions
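In short: torch._dynamo.export no longer takes the example inputs itself; it now returns a callable, and tracing happens when that callable is invoked with the inputs. A minimal before/after sketch (`f` and `x` are illustrative, not from this commit):

import torch
import torch._dynamo

def f(x):
    return x.sin() + 1

x = torch.randn(3)

# Before this PR: example inputs were passed to export directly.
#   gm, guards = torch._dynamo.export(f, x)

# After this PR: export(f) returns a callable that traces on invocation.
gm, guards = torch._dynamo.export(f)(x)
assert torch.allclose(gm(x), f(x))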

@@ -23,7 +23,7 @@ with torch._dynamo.config.patch(dynamic_shapes=False):
torch._dynamo.reset()
with torch.no_grad():
module, _ = torch._dynamo.export(Net().cuda(), x, y)
module, _ = torch._dynamo.export(Net().cuda())(x, y)
lib_path = torch._inductor.aot_compile(module, [x, y])
shutil.copy(lib_path, "libaot_inductor_output.so")
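For clarity, the post-change form of this snippet (reconstructed from the hunk above) is:

torch._dynamo.reset()
with torch.no_grad():
    module, _ = torch._dynamo.export(Net().cuda())(x, y)
    lib_path = torch._inductor.aot_compile(module, [x, y])
    shutil.copy(lib_path, "libaot_inductor_output.so")

Note that aot_compile still receives the example inputs directly; only the export call changes shape.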

@@ -154,7 +154,7 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
real = mod(rx)
# Run it in export
graph, _ = torch._dynamo.export(mod, rx)
graph, _ = torch._dynamo.export(mod)(rx)
# Run exported graph with AOT
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
@@ -185,7 +185,7 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
real = mod(x, y)
# Run it in export
graph, _ = torch._dynamo.export(mod, x, y)
graph, _ = torch._dynamo.export(mod)(x, y)
# Assert equal
self.assertTrue(torch._dynamo.testing.same(real, graph(x, y)))

@@ -238,7 +238,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
@@ -263,7 +263,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
real_device = real.device
real_dtype = real.dtype
graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
@@ -347,7 +347,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
@@ -521,7 +521,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
@@ -547,7 +547,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)

(One file's diff was suppressed because it is too large.)

@@ -10,11 +10,11 @@ from torch.testing._internal.common_utils import IS_FBCODE
class MutationExportTests(torch._dynamo.test_case.TestCase):
def check_failure_on_export(self, mod, *args):
with self.assertRaises(AssertionError):
torch._dynamo.export(mod, *args)
torch._dynamo.export(mod)(*args)
def check_same_with_export(self, mod, arg):
real_result = mod(arg)
graph, _ = torch._dynamo.export(mod, arg)
graph, _ = torch._dynamo.export(mod)(arg)
result = graph(arg)
self.assertTrue(torch._dynamo.utils.same(result, real_result))
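A consequence of the curried form, visible in check_failure_on_export above: export(mod) itself does no tracing, so any export failure (here, the expected AssertionError) is raised only when the returned callable runs. A sketch of that split:

exporter = torch._dynamo.export(mod)  # no tracing has happened yet
graph, guards = exporter(arg)         # tracing, and any export error, happens here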

@@ -3243,7 +3243,7 @@ def fn():
def f(x):
return 1 + torch._shape_as_tensor(x)[0]
gm, _ = torch._dynamo.export(f, torch.ones(6))
gm, _ = torch._dynamo.export(f)(torch.ones(6))
input_one_dim = torch.ones(6)
input_two_dims = torch.ones(7, 4)
@@ -3583,8 +3583,8 @@ def fn():
def f(pred, pred2, x):
return cond(pred, true_fn, false_fn, [pred2, x])
graph, guard = torch._dynamo.export(
f, torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
graph, guard = torch._dynamo.export(f)(
torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
)
true_true_sin = graph(
torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
@@ -3622,8 +3622,8 @@ def fn():
def f(pred, x):
return cond(pred, true_fn, false_fn, [x])
graph, guard = torch._dynamo.export(
f, torch.tensor(False), torch.tensor([0.25, 0.25])
graph, guard = torch._dynamo.export(f)(
torch.tensor(False), torch.tensor([0.25, 0.25])
)
true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25]))
self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror))
@@ -3891,7 +3891,7 @@ def fn():
return self.mod[0](x)
m = Mod()
graph, _ = torch._dynamo.export(m, torch.randn(3, 3))
graph, _ = torch._dynamo.export(m)(torch.randn(3, 3))
def test_nn_sequential_invocation(self):
with freeze_rng_state():
@@ -3913,7 +3913,7 @@ def fn():
m = TestModel()
x = torch.rand((2, 2))
real = m(x)
graph, _ = torch._dynamo.export(m, x)
graph, _ = torch._dynamo.export(m)(x)
dynamo_result = graph(x)
self.assertTrue(same(real, dynamo_result))
@@ -3937,7 +3937,7 @@ def fn():
m = TestModel()
x = torch.rand((2, 2))
real = m(x)
graph, _ = torch._dynamo.export(m, x)
graph, _ = torch._dynamo.export(m)(x)
dynamo_result = graph(x)
self.assertTrue(same(real, dynamo_result))
@@ -4724,7 +4724,7 @@ def fn():
b.tag = "b"
b.frog = "ribbit"
exported = torch._dynamo.export(foo, a, b)
exported = torch._dynamo.export(foo)(a, b)
out_graph = exported[0]
nodes = list(out_graph.graph.nodes)
@@ -4772,7 +4772,7 @@ def fn():
state[0].tag = "STATE_0"
state[1].tag = "HMMM"
exported = torch._dynamo.export(pre_attention_state_ops, i, mems, state)
exported = torch._dynamo.export(pre_attention_state_ops)(i, mems, state)
out_graph = exported[0]
nodes = list(out_graph.graph.nodes)

@@ -1410,7 +1410,7 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
mod = ModuleSpecialFwd()
rx = torch.randn([3, 10, 10])
real = mod(rx)
graph, _ = torch._dynamo.export(mod, rx)
graph, _ = torch._dynamo.export(mod)(rx)
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
def test_conv_call_forward_directly(self):

@@ -1188,7 +1188,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 3) # rand, rand
try:
graph, _ = torch._dynamo.export(fn)
graph, _ = torch._dynamo.export(fn)()
# See https://github.com/pytorch/pytorch/pull/87490
self.fail("unexpected export success")
except torch._dynamo.exc.Unsupported:
@@ -2650,11 +2650,11 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.op_count, 6)
self.assertEqual(cnt.frame_count, 1)
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))
with self.assertRaisesRegex(RuntimeError, "First dim need to be 3"):
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
def test_not_rewrite_assert_for_other_errors(self):
def f(x):
@@ -2686,11 +2686,11 @@ class ReproTests(torch._dynamo.test_case.TestCase):
return x.cos() + b
args = (torch.Tensor([3, 4, 5]),)
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))
with self.assertRaisesRegex(RuntimeError, "assertion error"):
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
def test_rewrite_assert_with_non_string_msg(self):
def f(x):
@@ -2718,7 +2718,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
return x.cos() + b
args = (torch.Tensor([3, 4, 5]),)
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))
cnt = torch._dynamo.testing.CompileCounter()
@@ -2728,7 +2728,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.op_count, 3)
self.assertEqual(cnt.frame_count, 1)
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))
def test_size_typematch(self):
@@ -2885,7 +2885,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
inp = torch.randn(6, 5)
gm, _ = torch._dynamo.export(f, torch.randn(4, 5), aten_graph=True)
gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
self.assertEqual(gm(inp).shape, f(inp).shape)
@torch._dynamo.config.patch("specialize_int", False)
@@ -3058,9 +3058,10 @@ class ReproTests(torch._dynamo.test_case.TestCase):
gm, _ = torch._dynamo.export(
f,
aten_graph=True,
)(
torch.zeros(6, 4),
torch.tensor(1),
aten_graph=True,
)
self.assertEqual(
f(torch.zeros(6, 4), torch.tensor(1)),
@@ -3158,8 +3159,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
gm, _ = torch._dynamo.export(
f,
torch.zeros(6, 4),
aten_graph=True,
)(
torch.zeros(6, 4),
)
self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))
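The recurring pattern in these hunks: export options such as aten_graph remain keyword arguments to export(...), while only the example inputs move to the returned callable. With f as in the tests above:

# Options bind first; example inputs are supplied in the second call.
gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.zeros(6, 4))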

@@ -30,7 +30,7 @@ class TestSourceMatcher(JitTestCase):
return x
inputs = (torch.randn(3, 3),)
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
module_partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
@@ -69,7 +69,7 @@ class TestSourceMatcher(JitTestCase):
return self.maxpool(self.relu(z))
inputs = (torch.randn(1, 3, 256, 256),)
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), *inputs, aten_graph=True)
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
module_partitions = get_source_partitions(gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d])
@@ -111,7 +111,7 @@ class TestSourceMatcher(JitTestCase):
return x
inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.conv2d])
@@ -135,7 +135,7 @@ class TestSourceMatcher(JitTestCase):
return x
inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu])

@@ -16,7 +16,7 @@ class TestFxPasses(common_utils.TestCase):
x = torch.randn(3)
y = torch.randn(3)
z = torch.randn(3)
gm, _ = torch._dynamo.export(func, x, y, z)
gm, _ = torch._dynamo.export(func)(x, y, z)
torch._dynamo.reset()
# Purposely name the nodes in a way that will cause a recursive collision later.
@@ -44,7 +44,7 @@ class TestFxPasses(common_utils.TestCase):
x = torch.randn(3)
y = torch.randn(3)
z = torch.randn(3)
gm, _ = torch._dynamo.export(func, x, y, z)
gm, _ = torch._dynamo.export(func)(x, y, z)
torch._dynamo.reset()
# Run `set_node_name` and verify that the names are correct.

@@ -833,7 +833,7 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
def export(
f: Callable[..., Any],
*args,
*extra_args,
aten_graph: bool = False,
pre_dispatch: bool = False,
decomposition_table: Optional[
@@ -843,16 +843,14 @@ def export(
constraints: Optional[List[Constraint]] = None,
assume_static_by_default: bool = False,
fake_mode: fake_tensor.FakeTensorMode = None,
**kwargs,
) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
**extra_kwargs,
) -> Callable[..., Tuple[torch.fx.GraphModule, Set[_guards.Guard]]]:
"""
Export an input function f to a format that can be executed outside of PyTorch using the FX graph.
Args:
f (callable): A PyTorch function to be exported.
*args: Variable length argument list to be passed to the function f.
aten_graph (bool): If True, exports a graph with ATen operators.
If False, exports a graph with Python operators. Default is False.
@@ -872,10 +870,8 @@ def export(
Useful during symbolic tracing, when user input is already fakefied. Implies free fake tensors
are allowed on `make_fx`. `fake_mode` must contain a valid (not None) `shape_env` instance.
**kwargs: Arbitrary keyword arguments to be passed to the function f.
Returns:
A tuple of (graph, guards)
A function that, given args and kwargs, returns a tuple of (graph, guards)
Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
Guards: The guards we accumulated during tracing f above
@@ -887,299 +883,329 @@ def export(
Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
"""
check_if_dynamo_supported()
torch._C._log_api_usage_once("torch._dynamo.export")
if decomposition_table is not None:
assert (
aten_graph
), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
if pre_dispatch:
assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
f = innermost_fn(f)
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
original_signature = inspect.signature(call_to_inspect)
# Deal with "local variable referenced before assignment"
_fake_mode = fake_mode
_f = f
_assume_static_by_default = assume_static_by_default
assert (
not fake_mode or fake_mode.shape_env is not None
), "The specified fake_mode must contain a valid shape_env"
graph = None
out_guards = None
graph_captured_input = None
graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
fake_mode = fake_mode or _guards.detect_fake_mode(args)
_allow_fake_constant: bool = (
fake_mode is not None
) # Allow fake constants during symbolic tracing
def produce_matching(source_args, candidate_args):
matched_elements_positions = []
dict_of_source_args = dict()
for i in range(0, len(source_args)):
element_id = id(source_args[i])
dict_of_source_args[element_id] = i
for i in range(0, len(candidate_args)):
arg = candidate_args[i]
# 1-element tensor arg can be unspec int/float
if isinstance(arg, torch.Tensor) and torch.numel(arg) == 1:
if id(arg) in dict_of_source_args:
matched_elements_positions.append(dict_of_source_args[id(arg)])
elif id(arg.item()) in dict_of_source_args:
matched_elements_positions.append(
dict_of_source_args[id(arg.item())]
)
else:
raise AssertionError(
"Dynamo input/output is not consistent with traced input/output"
)
else:
assert (
id(arg) in dict_of_source_args
), "Dynamo input and output is a strict subset of traced input/output"
matched_elements_positions.append(dict_of_source_args[id(arg)])
return matched_elements_positions
def guard_export_print(guards: Set[_guards.Guard]):
nonlocal out_guards
assert out_guards is None, "whole graph export entails exactly one guard export"
out_guards = guards
example_inputs = []
def dynamo_normalization_capturing_compiler(
gm: torch.fx.GraphModule, inner_example_inputs
):
nonlocal graph
assert (
graph is None
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
graph = gm
nonlocal fake_mode, example_inputs
fake_mode = fake_mode or _guards.detect_fake_mode(inner_example_inputs)
example_inputs = inner_example_inputs
def result_capturing_wrapper(*graph_inputs):
nonlocal graph_captured_result
nonlocal graph_captured_input
graph_captured_input = graph_inputs
assert graph is not None
graph_captured_result = graph(*graph_inputs)
return graph_captured_result
return result_capturing_wrapper
flat_args, in_spec = pytree.tree_flatten((args, kwargs))
remove_from_cache(f)
constraint_violation_error = None
if tracing_mode != "symbolic":
assume_static_by_default = True
with patch(f"{__name__}.most_recent_backend", None), config.patch(
specialize_int=True,
assume_static_by_default=assume_static_by_default,
automatic_dynamic_shapes=False,
capture_dynamic_output_shape_ops=True,
capture_scalar_outputs=True,
), torch._guards.export_fake_mode(fake_mode):
opt_f = optimize_assert(
dynamo_normalization_capturing_compiler,
hooks=Hooks(
guard_export_fn=guard_export_print,
guard_fail_fn=None,
),
export=True,
export_constraints=constraints,
)(f)
# TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
try:
result_traced = opt_f(*args, **kwargs)
except ConstraintViolationError as e:
constraint_violation_error = e
remove_from_cache(f)
if (
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
and (dim_constraints := shape_env.dim_constraints) is not None
and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
):
dim_constraints.solve()
msg = dim_constraints.prettify_results(original_signature)
forced_specializations = dim_constraints.forced_specializations()
if forced_specializations:
msg = (
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify.\n"
f"{forced_specializations}\n{msg}"
)
if constraint_violation_error:
constraint_violation_error.args = (
constraint_violation_error.args[0] + msg,
)
else:
if forced_specializations:
constraint_violation_error = ConstraintViolationError(msg)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)
# Error if we have any constraints on static values
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "
"Scroll up to see where this constraint was set."
)
if constraint_violation_error:
raise constraint_violation_error
assert (
graph is not None
), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph."
assert out_guards is not None, "Failed to produce guards during tracing"
assert fake_mode is not None
matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)
# NB: This is mostly hitting the cache; Dynamo already converted these
example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
assert graph_captured_result is not None
flat_both = list(graph_captured_result) + flat_args
matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
if aten_graph:
# Running graph with interpreter is needed for propagating the stack_trace
def graph_with_interpreter(*args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graph).run(*args)
with enable_python_dispatcher(), fake_mode:
try:
graph = make_fx(
graph_with_interpreter,
decomposition_table=decomposition_table,
tracing_mode="real",
_allow_non_fake_inputs=True,
pre_dispatch=pre_dispatch,
_allow_fake_constant=_allow_fake_constant,
)(*example_fake_inputs)
except CondOpArgsMismatchError as e:
# Wrap the internal error to the user-facing error
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
new_graph = FlattenInputOutputSignature(
graph,
flat_args,
matched_input_elements_positions,
matched_output_elements_positions,
example_fake_inputs,
fake_mode,
).transform()
# Store constraints and inputs as metadata for user passes, e.g. to turn constraints into runtime checks
new_graph.meta["input_shape_constraints"] = (
[constraint.serializable_spec for constraint in constraints]
if constraints
else []
)
def signature_to_fullargspec(sig: inspect.Signature):
# Get a list of Parameter objects from the Signature object
params = list(sig.parameters.values())
# Separate positional arguments, keyword-only arguments and varargs/varkw
args = [
p.name for p in params if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
]
kwonlyargs = [
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
]
varargs = next(
(p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL), None
)
varkw = next(
(p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD), None
)
# Get default values for positional arguments and keyword-only arguments
defaults = tuple(
p.default
for p in params
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
and p.default is not inspect.Parameter.empty
)
kwonlydefaults = {
p.name: p.default
for p in params
if p.kind == inspect.Parameter.KEYWORD_ONLY
and p.default is not inspect.Parameter.empty
}
# Get annotations for parameters and return value
annotations = {}
if sig.return_annotation:
annotations = {"return": sig.return_annotation}
for parameter in params:
annotations[parameter.name] = parameter.annotation
# Return a FullArgSpec object with the extracted attributes
return inspect.FullArgSpec(
args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
)
# Make the dynamo graph have the same input/output spec as the user code
def argument_names(f: Callable[..., Any], *args, **kwargs) -> List[str]:
fullargspec = signature_to_fullargspec(original_signature)
# 1. Map `args` 1-to-1 to positional arguments in original signature.
input_strs = fullargspec.args[: len(args)]
if len(args) > len(fullargspec.args):
# 2. If there are more arguments left in `args`, they map to varargs in original
# signature. Assign names as {varargs}_0, {varargs}_1, ...
assert fullargspec.varargs is not None, "More arguments than expected"
input_strs += [
f"{fullargspec.varargs}_{i}"
for i in range(0, len(args) - len(input_strs))
]
elif len(args) < len(fullargspec.args):
# 3. If there are fewer arguments in `args` than `fullargspec.args`,
# it implies these are arguments either with default values, or provided in
# `kwargs`. The former can be safely ignored, because Dynamo.export does not
# export them as part of the function signature. The latter will be handled
# in the next step.
for unprovided_arg in fullargspec.args[
len(args) : -len(fullargspec.defaults or [])
]:
assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
# 4. Keyword arguments provided in `kwargs`.
input_strs += list(kwargs.keys())
# 5. Keyword-only arguments with default values, if not provided, are not exported
# as part of the function signature.
for kwonly_arg in fullargspec.kwonlyargs:
kwonlydefaults = fullargspec.kwonlydefaults or {}
def inner(*args, **kwargs):
fake_mode = _fake_mode
f = _f
assume_static_by_default = _assume_static_by_default
check_if_dynamo_supported()
torch._C._log_api_usage_once("torch._dynamo.export")
if decomposition_table is not None:
assert (
kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
), f"Missing keyword only argument {kwonly_arg}"
aten_graph
), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
if pre_dispatch:
assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
f = innermost_fn(f)
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
original_signature = inspect.signature(call_to_inspect)
assert (
not fake_mode or fake_mode.shape_env is not None
), "The specified fake_mode must contain a valid shape_env"
graph = None
out_guards = None
graph_captured_input = None
graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
fake_mode = fake_mode or _guards.detect_fake_mode(args)
_allow_fake_constant: bool = (
fake_mode is not None
) # Allow fake constants during symbolic tracing
return input_strs
def produce_matching(source_args, candidate_args):
matched_elements_positions = []
dict_of_source_args = dict()
for i in range(0, len(source_args)):
element_id = id(source_args[i])
dict_of_source_args[element_id] = i
new_graph.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(
argument_names(f, *args, **kwargs),
in_spec,
out_spec_traced,
for i in range(0, len(candidate_args)):
arg = candidate_args[i]
# 1-element tensor arg can be unspec int/float
if isinstance(arg, torch.Tensor) and torch.numel(arg) == 1:
if id(arg) in dict_of_source_args:
matched_elements_positions.append(dict_of_source_args[id(arg)])
elif id(arg.item()) in dict_of_source_args:
matched_elements_positions.append(
dict_of_source_args[id(arg.item())]
)
else:
raise AssertionError(
"Dynamo input/output is not consistent with traced input/output"
)
else:
assert (
id(arg) in dict_of_source_args
), "Dynamo input and output is a strict subset of traced input/output"
matched_elements_positions.append(dict_of_source_args[id(arg)])
return matched_elements_positions
def guard_export_print(guards: Set[_guards.Guard]):
nonlocal out_guards
assert (
out_guards is None
), "whole graph export entails exactly one guard export"
out_guards = guards
example_inputs = []
def dynamo_normalization_capturing_compiler(
gm: torch.fx.GraphModule, inner_example_inputs
):
nonlocal graph
assert (
graph is None
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
graph = gm
nonlocal fake_mode, example_inputs
fake_mode = fake_mode or _guards.detect_fake_mode(inner_example_inputs)
example_inputs = inner_example_inputs
def result_capturing_wrapper(*graph_inputs):
nonlocal graph_captured_result
nonlocal graph_captured_input
graph_captured_input = graph_inputs
assert graph is not None
graph_captured_result = graph(*graph_inputs)
return graph_captured_result
return result_capturing_wrapper
flat_args, in_spec = pytree.tree_flatten((args, kwargs))
remove_from_cache(f)
constraint_violation_error = None
if tracing_mode != "symbolic":
assume_static_by_default = True
with patch(f"{__name__}.most_recent_backend", None), config.patch(
specialize_int=True,
assume_static_by_default=assume_static_by_default,
automatic_dynamic_shapes=False,
capture_dynamic_output_shape_ops=True,
capture_scalar_outputs=True,
), torch._guards.export_fake_mode(fake_mode):
opt_f = optimize_assert(
dynamo_normalization_capturing_compiler,
hooks=Hooks(
guard_export_fn=guard_export_print,
guard_fail_fn=None,
),
export=True,
export_constraints=constraints,
)(f)
# TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
try:
result_traced = opt_f(*args, **kwargs)
except ConstraintViolationError as e:
constraint_violation_error = e
remove_from_cache(f)
if (
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
and (dim_constraints := shape_env.dim_constraints) is not None
and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
):
dim_constraints.solve()
msg = dim_constraints.prettify_results(original_signature)
forced_specializations = dim_constraints.forced_specializations()
if forced_specializations:
msg = (
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify.\n"
f"{forced_specializations}\n{msg}"
)
if constraint_violation_error:
constraint_violation_error.args = (
constraint_violation_error.args[0] + msg,
)
else:
if forced_specializations:
constraint_violation_error = ConstraintViolationError(msg)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)
# Error if we have any constraints on static values
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "
"Scroll up to see where this constraint was set."
)
if constraint_violation_error:
raise constraint_violation_error
assert (
graph is not None
), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph."
assert out_guards is not None, "Failed to produce guards during tracing"
assert fake_mode is not None
matched_input_elements_positions = produce_matching(
flat_args, graph_captured_input
)
)
new_graph.recompile()
return (new_graph, out_guards)
# NB: This is mostly hitting the cache; Dynamo already converted these
example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
assert graph_captured_result is not None
flat_both = list(graph_captured_result) + flat_args
matched_output_elements_positions = produce_matching(
flat_both, flat_results_traced
)
if aten_graph:
# Running graph with interpreter is needed for propagating the stack_trace
def graph_with_interpreter(*args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graph).run(*args)
with enable_python_dispatcher(), fake_mode:
try:
graph = make_fx(
graph_with_interpreter,
decomposition_table=decomposition_table,
tracing_mode="real",
_allow_non_fake_inputs=True,
pre_dispatch=pre_dispatch,
_allow_fake_constant=_allow_fake_constant,
)(*example_fake_inputs)
except CondOpArgsMismatchError as e:
# Wrap the internal error to the user-facing error
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
new_graph = FlattenInputOutputSignature(
graph,
flat_args,
matched_input_elements_positions,
matched_output_elements_positions,
example_fake_inputs,
fake_mode,
).transform()
# Store constraints and inputs as metadata for user passes, e.g. to turn constraints into runtime checks
new_graph.meta["input_shape_constraints"] = (
[constraint.serializable_spec for constraint in constraints]
if constraints
else []
)
def signature_to_fullargspec(sig: inspect.Signature):
# Get a list of Parameter objects from the Signature object
params = list(sig.parameters.values())
# Separate positional arguments, keyword-only arguments and varargs/varkw
args = [
p.name
for p in params
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
]
kwonlyargs = [
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
]
varargs = next(
(p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
None,
)
varkw = next(
(p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
None,
)
# Get default values for positional arguments and keyword-only arguments
defaults = tuple(
p.default
for p in params
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
and p.default is not inspect.Parameter.empty
)
kwonlydefaults = {
p.name: p.default
for p in params
if p.kind == inspect.Parameter.KEYWORD_ONLY
and p.default is not inspect.Parameter.empty
}
# Get annotations for parameters and return value
annotations = {}
if sig.return_annotation:
annotations = {"return": sig.return_annotation}
for parameter in params:
annotations[parameter.name] = parameter.annotation
# Return a FullArgSpec object with the extracted attributes
return inspect.FullArgSpec(
args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
)
# Make the dynamo graph have the same input/output spec as the user code
def argument_names(f: Callable[..., Any], *args, **kwargs) -> List[str]:
fullargspec = signature_to_fullargspec(original_signature)
# 1. Map `args` 1-to-1 to positional arguments in original signature.
input_strs = fullargspec.args[: len(args)]
if len(args) > len(fullargspec.args):
# 2. If there are more arguments left in `args`, they map to varargs in original
# signature. Assign names as {varargs}_0, {varargs}_1, ...
assert fullargspec.varargs is not None, "More arguments than expected"
input_strs += [
f"{fullargspec.varargs}_{i}"
for i in range(0, len(args) - len(input_strs))
]
elif len(args) < len(fullargspec.args):
# 3. If there are fewer arguments in `args` than `fullargspec.args`,
# it implies these are arguments either with default values, or provided in
# `kwargs`. The former can be safely ignored, because Dynamo.export does not
# export them as part of the function signature. The latter will be handled
# in the next step.
for unprovided_arg in fullargspec.args[
len(args) : -len(fullargspec.defaults or [])
]:
assert (
unprovided_arg in kwargs
), f"Missing argument {unprovided_arg}"
# 4. Keyword arguments provided in `kwargs`.
input_strs += list(kwargs.keys())
# 5. Keyword-only arguments with default values, if not provided, are not exported
# as part of the function signature.
for kwonly_arg in fullargspec.kwonlyargs:
kwonlydefaults = fullargspec.kwonlydefaults or {}
assert (
kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
), f"Missing keyword only argument {kwonly_arg}"
return input_strs
new_graph.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(
argument_names(f, *args, **kwargs),
in_spec,
out_spec_traced,
)
)
new_graph.recompile()
return (new_graph, out_guards)
if extra_args or extra_kwargs:
warnings.warn(
"export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. "
"If you don't migrate, we may break your export call in the future if your user defined kwargs "
"conflict with future kwargs added to export(f)."
)
return inner(*extra_args, **extra_kwargs)
else:
return inner
def optimize_assert(
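Per the compatibility shim just above, the old calling convention keeps working during the deprecation window: extra positional or keyword arguments passed to export itself trigger the warning and are forwarded to inner. So both of these trace f with x:

gm, guards = torch._dynamo.export(f, x)   # deprecated: warns, then traces
gm, guards = torch._dynamo.export(f)(x)   # preferred curried form

The rationale given in the warning text: once user inputs no longer share export's keyword namespace, future export options cannot collide with user-defined kwargs.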

@@ -62,7 +62,6 @@ from .passes.add_runtime_assertions_for_constraints_pass import _AddRuntimeAsser
#
# result = torch._dynamo.export(
# my_model,
# *sixtyfour_tensors,
# constraints=[
# # if you do only dynamic_dim, this is sugar for
# # -Inf <= dynamic_dim(blah, 0) <= Inf; we don't otherwise
@@ -74,6 +73,8 @@ from .passes.add_runtime_assertions_for_constraints_pass import _AddRuntimeAsser
# # NB: But we actually truncate ranges to be >= 2, because of
# # 0/1 specialization
# ]
# )(
# *sixtyfour_tensors,
# )
def dynamic_dim(t: torch.Tensor, index: int):
if not isinstance(t, torch.Tensor):
@@ -152,10 +153,11 @@ def export(
try:
gm_torch_level, _ = torch._dynamo.export(
f,
*args,
constraints=constraints,
assume_static_by_default=True,
tracing_mode="symbolic",
)(
*args,
**kwargs,
)
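A hedged sketch of the resulting user-facing call shape with constraints, using the dynamic_dim sugar described in the comment above (`f` and the shape are illustrative; per that comment, a bare dynamic_dim means an unbounded range, truncated to >= 2 by 0/1 specialization):

from torch._export import dynamic_dim  # defined in this file

x = torch.randn(8, 4)
gm, guards = torch._dynamo.export(
    f,
    constraints=[dynamic_dim(x, 0)],  # dim 0 of x may vary
    assume_static_by_default=True,
    tracing_mode="symbolic",
)(x)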

@@ -141,9 +141,10 @@ def get_aten_graph_module(
import torch._dynamo
aten_pattern, _ = torch._dynamo.export(
pattern,
*copy.deepcopy(example_inputs),
aten_graph=True,
tracing_mode="real",
)(
*copy.deepcopy(example_inputs),
**kwargs,
)
aten_pattern.graph.eliminate_dead_code()
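Reconstructed, the post-change call reads as follows; the deep copy keeps pattern tracing from mutating the caller's example inputs:

aten_pattern, _ = torch._dynamo.export(
    pattern,
    aten_graph=True,
    tracing_mode="real",
)(
    *copy.deepcopy(example_inputs),
    **kwargs,
)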

@@ -64,7 +64,7 @@ def _mark_nodes_as_annotated(nodes: List[Node]):
def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
gm, _ = torchdynamo.export(function, *inputs, aten_graph=True)
gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
return gm.graph

@@ -191,9 +191,10 @@ class DynamoExport(exporter.FXGraphExtractor):
fx_mode = "symbolic" if options.dynamic_shapes else "fake"
graph_module, graph_guard = torch._dynamo.export(
wrapped_model,
*model_args,
tracing_mode=fx_mode,
fake_mode=fake_mode, # type: ignore[arg-type]
)(
*model_args,
**model_kwargs,
)
del graph_guard # Unused
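A sketch of supplying fake_mode yourself, as this extractor does; per the docstring earlier in this commit, the mode must carry a valid shape_env and the inputs are expected to be already fakefied (the model and shapes below are illustrative assumptions, not from this commit):

from torch._subclasses import fake_tensor
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fake_mode = fake_tensor.FakeTensorMode(shape_env=ShapeEnv())
fake_args = [fake_mode.from_tensor(torch.randn(3))]
gm, _ = torch._dynamo.export(
    wrapped_model,
    tracing_mode="symbolic",
    fake_mode=fake_mode,
)(*fake_args)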

@@ -798,7 +798,7 @@ class Modularize(_pass.Transform):
>>> out = self.linear(out)
>>> return out
>>>
>>> gm, _ = torch._dynamo.export(TestModule(), torch.tensor([0, 1, 2]), aten_graph=True)
>>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(torch.tensor([0, 1, 2]))
>>> gm.print_readable()
>>> gm = passes.Modularize(infra.DiagnosticContext("test_context", "1.0"), gm).run()