From 7b9d250f069deca73dfb40fddebd08beb2888127 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 27 Jul 2023 14:39:15 -0700 Subject: [PATCH] Change _dynamo.export to be export(f)(*args, **kwargs) (#106109) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/106109 Approved by: https://github.com/voznesenskym --- test/cpp/aot_inductor/test.py | 2 +- test/dynamo/test_aot_autograd.py | 4 +- test/dynamo/test_ctx_manager.py | 10 +- test/dynamo/test_export.py | 316 +++++---- test/dynamo/test_export_mutations.py | 4 +- test/dynamo/test_misc.py | 20 +- test/dynamo/test_modules.py | 2 +- test/dynamo/test_repros.py | 22 +- test/fx/test_source_matcher_utils.py | 8 +- test/onnx/test_fx_passes.py | 4 +- torch/_dynamo/eval_frame.py | 616 +++++++++--------- torch/_export/__init__.py | 6 +- torch/ao/quantization/pt2e/utils.py | 3 +- .../quantizer/xnnpack_quantizer.py | 2 +- .../_internal/fx/dynamo_graph_extractor.py | 3 +- .../_internal/fx/passes/modularization.py | 2 +- 16 files changed, 521 insertions(+), 503 deletions(-) diff --git a/test/cpp/aot_inductor/test.py b/test/cpp/aot_inductor/test.py index b5a623c13ab..20fbe23c29b 100644 --- a/test/cpp/aot_inductor/test.py +++ b/test/cpp/aot_inductor/test.py @@ -23,7 +23,7 @@ with torch._dynamo.config.patch(dynamic_shapes=False): torch._dynamo.reset() with torch.no_grad(): - module, _ = torch._dynamo.export(Net().cuda(), x, y) + module, _ = torch._dynamo.export(Net().cuda())(x, y) lib_path = torch._inductor.aot_compile(module, [x, y]) shutil.copy(lib_path, "libaot_inductor_output.so") diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index e0a545d7b68..cd091f95d76 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -154,7 +154,7 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): real = mod(rx) # Run it in export - graph, _ = torch._dynamo.export(mod, rx) + graph, _ = torch._dynamo.export(mod)(rx) # Run exported graph with AOT self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) @@ -185,7 +185,7 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): real = mod(x, y) # Run it in export - graph, _ = torch._dynamo.export(mod, x, y) + graph, _ = torch._dynamo.export(mod)(x, y) # Assert equal self.assertTrue(torch._dynamo.testing.same(real, graph(x, y))) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index cf07046ab1d..5a0d81145fc 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -238,7 +238,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): real_device = real.device real_dtype = real.dtype - graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) + graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) exported = graph(torch.tensor([0.5])) self.assertEqual(exported.device, real_device) self.assertEqual(exported.dtype, real_dtype) @@ -263,7 +263,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): real_device = real.device real_dtype = real.dtype - graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) + graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) exported = graph(torch.tensor([0.5])) self.assertEqual(exported.device, real_device) self.assertEqual(exported.dtype, real_dtype) @@ -347,7 +347,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): real_device = real.device real_dtype = real.dtype - graph, guards = 
torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) + graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) exported = graph(torch.tensor([0.5])) self.assertEqual(exported.device, real_device) self.assertEqual(exported.dtype, real_dtype) @@ -521,7 +521,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): real_device = real.device real_dtype = real.dtype - graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) + graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) exported = graph(torch.tensor([0.5])) self.assertEqual(exported.device, real_device) self.assertEqual(exported.dtype, real_dtype) @@ -547,7 +547,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): real_device = real.device real_dtype = real.dtype - graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) + graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) exported = graph(torch.tensor([0.5])) self.assertEqual(exported.device, real_device) self.assertEqual(exported.dtype, real_dtype) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 35052b87f99..2f71fba5494 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -65,7 +65,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func) + exported = torch._dynamo.export(func)() out_graph = exported[0] dynamo_result = out_graph() @@ -81,7 +81,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, torch.tensor([[[1.3737, 0.1]]])) + exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]])) out_graph = exported[0] dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) @@ -99,7 +99,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, torch.ones(6, 4)) + exported = torch._dynamo.export(func)(torch.ones(6, 4)) out_graph, out_guards = exported dynamo_result = out_graph(torch.ones(6, 4)) @@ -133,7 +133,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule("moo") input = (torch.ones(4, 3),) resA = module(*input) - graph, _ = torch._dynamo.export(module, *input) + graph, _ = torch._dynamo.export(module)(*input) resB = graph(*input) self.assertTrue(torch._dynamo.utils.same(resA, resB)) @@ -154,7 +154,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp) + exported = torch._dynamo.export(func)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -178,7 +178,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp) + exported = torch._dynamo.export(func)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -191,7 +191,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): return y[0] + y[1] inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])] - gm, _ = torch._dynamo.export(f, inp, aten_graph=True, tracing_mode="symbolic") + gm, _ = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(inp) self.assertTrue(torch._dynamo.utils.same(gm(inp), f(inp))) def test_export_with_shallow_list_copy_with_side_effects(self): @@ -202,7 +202,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): return x[0] + x[1], y[0] + y[1], y[2] inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])] - 
gm, _ = torch._dynamo.export(f, inp, aten_graph=True, tracing_mode="symbolic") + gm, _ = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(inp) res = gm(inp) ref = f(inp) self.assertTrue(torch._dynamo.utils.same(res, ref)) @@ -218,7 +218,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, torch.tensor([[[1.3737, 0.1]]])) + exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]])) out_graph = exported[0] dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) @@ -243,7 +243,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp) + exported = torch._dynamo.export(func)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -269,7 +269,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp) + exported = torch._dynamo.export(func)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -288,7 +288,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp) + exported = torch._dynamo.export(func)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -307,7 +307,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp) + exported = torch._dynamo.export(func)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -328,7 +328,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) @@ -350,7 +350,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) @@ -372,7 +372,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) @@ -395,7 +395,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) @@ -418,7 +418,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps_rand) @@ -442,7 +442,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps_rand) @@ -466,7 +466,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps_rand) @@ -490,7 +490,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps_rand) @@ -518,7 +518,7 @@ class 
ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps_rand) @@ -542,7 +542,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps_rand) @@ -584,7 +584,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)() out_graph = exported[0] dynamo_result = out_graph() @@ -600,8 +600,8 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export( - func, torch.tensor([[[1.3737, 0.1]]]), aten_graph=True + exported = torch._dynamo.export(func, aten_graph=True)( + torch.tensor([[[1.3737, 0.1]]]) ) out_graph = exported[0] @@ -626,7 +626,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -650,7 +650,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -667,8 +667,8 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export( - func, torch.tensor([[[1.3737, 0.1]]]), aten_graph=True + exported = torch._dynamo.export(func, aten_graph=True)( + torch.tensor([[[1.3737, 0.1]]]) ) out_graph = exported[0] @@ -694,7 +694,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -720,7 +720,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -739,7 +739,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -758,7 +758,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, inp, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) @@ -779,7 +779,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) @@ -801,7 +801,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) 
@@ -823,7 +823,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) @@ -846,7 +846,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps) + exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) @@ -869,7 +869,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps_rand) @@ -897,7 +897,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps_rand) @@ -921,7 +921,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps_rand) @@ -945,7 +945,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): out = self.block(x) return out - exported = torch._dynamo.export(MyModule(), inp, aten_graph=False) + exported = torch._dynamo.export(MyModule(), aten_graph=False)(inp) out_graph = exported[0] for node in out_graph.graph.nodes: @@ -956,7 +956,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(MyModule(), inp, aten_graph=True) + exported = torch._dynamo.export(MyModule(), aten_graph=True)(inp) out_graph = exported[0] for node in out_graph.graph.nodes: if node.op == "call_function": @@ -989,7 +989,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): return out m = MyModule() - exported = torch._dynamo.export(m, inp, aten_graph=False) + exported = torch._dynamo.export(m, aten_graph=False)(inp) out_graph = exported[0] attr_access_count = 0 @@ -1001,7 +1001,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(m, inp, aten_graph=True) + exported = torch._dynamo.export(m, aten_graph=True)(inp) out_graph = exported[0] attr_access_count = 0 @@ -1022,7 +1022,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): y = linear(y) return y - exported = torch._dynamo.export(func, inp, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(inp) out_graph = exported[0] export_result = out_graph(inp) @@ -1068,7 +1068,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule() real_result = module(torch.tensor([[1.0, 0], [0, 0]])) module = MyModule() - graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) + graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) @@ -1094,7 +1094,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule() real_result = module(torch.tensor([[1.0, 0], [0, 0]])) module = MyModule() - graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) + graph, _ = 
torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) @@ -1124,7 +1124,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule() real_result = module(torch.tensor([[1.0, 0], [0, 0]])) module = MyModule() - graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) + graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) @@ -1150,7 +1150,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule() real_result = module(torch.tensor([[1.0, 0], [0, 0]])) module = MyModule() - graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) + graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]])) result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) @@ -1178,8 +1178,8 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) ) module = MyModule() - graph, _ = torch._dynamo.export( - module, torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) + graph, _ = torch._dynamo.export(module)( + torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) ) result = graph( torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]]) @@ -1205,8 +1205,8 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) ) module = MyModule() - graph, _ = torch._dynamo.export( - module, torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]]) + graph, _ = torch._dynamo.export(module)( + torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]]) ) result = graph( torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]]) @@ -1235,7 +1235,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule() real_result = module(torch.tensor([1.0, 1.0])) - graph, guards = torch._dynamo.export(module, torch.tensor([1.0, 1.0])) + graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])) # Tensor input can be almost anything here, and the result will capture what we # made constant at compile time. @@ -1259,7 +1259,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule() real_result = module(torch.tensor([1.0, 1.0])) - graph, guards = torch._dynamo.export(module, torch.tensor([1.0, 1.0])) + graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])) # Tensor input can be almost anything here, and the result will capture what we # made constant at compile time. @@ -1283,7 +1283,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule() real_result = module(torch.tensor([1.0, 1.0])) - graph, guards = torch._dynamo.export(module, torch.tensor([1.0, 1.0])) + graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])) # Tensor input can be almost anything here, and the result will capture what we # made constant at compile time. 
@@ -1305,7 +1305,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule() real_result = module(torch.tensor([2.0, 2.0])) - graph, guards = torch._dynamo.export(module, torch.tensor([2.0, 2.0])) + graph, guards = torch._dynamo.export(module)(torch.tensor([2.0, 2.0])) # Tensor input can be almost anything here, and the result will capture what we # made constant at compile time. @@ -1334,7 +1334,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): # X is negative, so .item() < 0, which means we return y self.assertEqual(real_result, torch.tensor([0.5])) - graph, guards = torch._dynamo.export(module, torch.tensor([-1])) + graph, guards = torch._dynamo.export(module)(torch.tensor([-1])) result = graph(torch.tensor([2])) # X is positive, but we compiled helper_fn to return None, so it will still return y self.assertTrue(torch._dynamo.utils.same(result, real_result)) @@ -1361,7 +1361,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): # X is positive, so .item() > 0, which means we return y * x self.assertEqual(real_result, torch.tensor([1.0])) - graph, guards = torch._dynamo.export(module, torch.tensor([2])) + graph, guards = torch._dynamo.export(module)(torch.tensor([2])) result = graph(torch.tensor([-0.5])) # X is negative, but we compiled helper_fn to return x, so it will still return y * x self.assertTrue(torch._dynamo.utils.same(result, real_result)) @@ -1388,7 +1388,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): # X is negative, so .item() < 0, which means we return y self.assertEqual(real_result, torch.tensor([0.5])) - graph, guards = torch._dynamo.export(module, torch.tensor([-1])) + graph, guards = torch._dynamo.export(module)(torch.tensor([-1])) result = graph(torch.tensor([2])) # X is positive, but we compiled helper_fn to return None, so it will still return y self.assertTrue(torch._dynamo.utils.same(result, real_result)) @@ -1415,7 +1415,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): # X is positive, so .item() > 0, which means we return y * x self.assertEqual(real_result, torch.tensor([1.0])) - graph, guards = torch._dynamo.export(module, torch.tensor([2])) + graph, guards = torch._dynamo.export(module)(torch.tensor([2])) result = graph(torch.tensor([-0.5])) # X is negative, but we compiled helper_fn to return x, so it will still return y * x self.assertTrue(torch._dynamo.utils.same(result, real_result)) @@ -1442,7 +1442,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): # X is positive, so .item() > 0, which means we return y * x self.assertEqual(real_result, torch.tensor([1.0])) - graph, guards = torch._dynamo.export(module, torch.tensor([2])) + graph, guards = torch._dynamo.export(module)(torch.tensor([2])) result = graph(torch.tensor([-0.5])) # X is negative, but we compiled helper_fn to return x, so it will still return y * x self.assertTrue(torch._dynamo.utils.same(result, real_result)) @@ -1463,7 +1463,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): module = MyModule() module.val = "A" resA = module(torch.tensor([2])) - graph, guards = torch._dynamo.export(module, torch.tensor([2])) + graph, guards = torch._dynamo.export(module)(torch.tensor([2])) module.val = "B" resB = graph(torch.tensor([2])) self.assertTrue(torch._dynamo.utils.same(resA, resB)) @@ -1477,17 +1477,16 @@ class ExportTests(torch._dynamo.test_case.TestCase): graph, _ = torch._dynamo.export( f, - (torch.randn(5)), aten_graph=True, decomposition_table={torch.ops.aten.t.default: nop}, - ) + )(torch.randn(5)) self.assertEqual( len([n for n in 
graph.graph.nodes if n.target == torch.ops.aten.t.default]), 0, ) - graph, _ = torch._dynamo.export( - f, (torch.randn(5)), aten_graph=True, decomposition_table=None + graph, _ = torch._dynamo.export(f, aten_graph=True, decomposition_table=None)( + torch.randn(5) ) self.assertEqual( len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]), @@ -1534,7 +1533,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.reset() - exported = torch._dynamo.export(mod.forward, pred, x) + exported = torch._dynamo.export(mod.forward)(pred, x) out_graph = exported[0] dynamo_result = out_graph(pred, x) @@ -1576,7 +1575,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): x = torch.randn([3, 3]) pred = torch.tensor(x[0][0].item() < 0) real_result = mod.forward(pred, x) - out_graph, _ = torch._dynamo.export(mod.forward, pred, x) + out_graph, _ = torch._dynamo.export(mod.forward)(pred, x) dynamo_result = out_graph(pred, x) self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) @@ -1631,7 +1630,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): x = torch.randn([3, 3]) pred = torch.tensor(x[0][0].item() < 0) real_result = mod.forward(pred, x) - out_graph, _ = torch._dynamo.export(mod.forward, pred, x) + out_graph, _ = torch._dynamo.export(mod.forward)(pred, x) dynamo_result = out_graph(pred, x) self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) @@ -1654,7 +1653,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): x = torch.randn(5) pred = x[0] > 0 real_result = foo(pred, x) - out_graph, _ = torch._dynamo.export(foo, pred, x) + out_graph, _ = torch._dynamo.export(foo)(pred, x) dynamo_result = out_graph(pred, x) self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) @@ -1673,7 +1672,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): mod = Module() x = torch.randn(2, 2) - out_graph, _ = torch._dynamo.export(mod, x) + out_graph, _ = torch._dynamo.export(mod)(x) test_x = torch.randn(3, 2) self.assertEqual(out_graph(test_x), mod(test_x)) @@ -1704,7 +1703,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): pred_y = torch.tensor(False) real_result = mod(pred_y, y) - out_graph, _ = torch._dynamo.export(mod, pred_x, x) + out_graph, _ = torch._dynamo.export(mod)(pred_x, x) self.assertEqual(real_result, out_graph(pred_y, y)) def test_export_with_map_zero_sized_tensor(self): @@ -1723,7 +1722,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): torch._dynamo.exc.Unsupported, "zero-sized tensor", ): - out_graph, _ = torch._dynamo.export(mod, xs) + out_graph, _ = torch._dynamo.export(mod)(xs) def test_export_meta_val(self): def f(x, y, z): @@ -1731,10 +1730,11 @@ class ExportTests(torch._dynamo.test_case.TestCase): gm, _ = torch._dynamo.export( f, + aten_graph=True, + )( torch.ones(3, 2), torch.zeros(3, 2), torch.ones(3, 2), - aten_graph=True, ) for node in gm.graph.nodes: if node.op == "placeholder": @@ -1746,7 +1746,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)]) - gm, _ = torch._dynamo.export(f, *inp, aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp) self.assertEqual(gm(*inp), f(*inp)) @@ -1756,7 +1756,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): return torch.empty(x.shape[0] * 2) inp = (torch.randn(6, 5),) - gm, _ = torch._dynamo.export(f, *inp, aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp) has_sym_size = False for node in gm.graph.nodes: @@ -1770,7 +1770,7 @@ class 
ExportTests(torch._dynamo.test_case.TestCase): def f(x): return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] - gm_aten_mode, _ = torch._dynamo.export(f, torch.randn(4, 5), aten_graph=True) + gm_aten_mode, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5)) inp = torch.randn(6, 7) self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape) @@ -1787,7 +1787,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): self.assertEqual(count, 2) - gm_torch_mode, _ = torch._dynamo.export(f, torch.randn(4, 5), aten_graph=False) + gm_torch_mode, _ = torch._dynamo.export(f, aten_graph=False)(torch.randn(4, 5)) # In torch mode, the graph should contain 3 getitem methods # one for x.shape[0]-2 and one for x.shape[1]-1 and one for slice @@ -1811,9 +1811,10 @@ class ExportTests(torch._dynamo.test_case.TestCase): ): torch._dynamo.export( g, + aten_graph=True, + )( torch.randn(4, 5), torch.tensor(2), - aten_graph=True, ) @config.patch(capture_scalar_outputs=True) @@ -1821,7 +1822,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): def f(x): return x[slice(None, None, None)] - gm, _ = torch._dynamo.export(f, torch.randn(4, 5), aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5)) inp = torch.randn(6, 7) self.assertEqual(gm(inp), f(inp)) @@ -1833,10 +1834,11 @@ class ExportTests(torch._dynamo.test_case.TestCase): gm, _ = torch._dynamo.export( f, - torch.randn(5, 5), aten_graph=True, pre_dispatch=True, tracing_mode="fake", + )( + torch.randn(5, 5), ) inp = torch.randn(6, 6) @@ -1869,7 +1871,7 @@ def forward(self, x): torch.randn(4, 4), torch.randn(4, 4), ) - gm, _ = torch._dynamo.export(model, *inp, aten_graph=True) + gm, _ = torch._dynamo.export(model, aten_graph=True)(*inp) gm.print_readable() @@ -2163,7 +2165,7 @@ def forward(self, x): m = MyModule() inp = torch.ones(2, 3, device="meta") - exported = torch._dynamo.export(m, inp) + exported = torch._dynamo.export(m)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) self.assertEqual(dynamo_result, m(inp)) @@ -2176,10 +2178,10 @@ def forward(self, x): return x.sin() return x.cos() - torch._dynamo.export(my_dyn_fn, y) + torch._dynamo.export(my_dyn_fn)(y) with self.assertRaises(ConstraintViolationError): - torch._dynamo.export(my_dyn_fn, y, constraints=[dynamic_dim(y, 0)]) + torch._dynamo.export(my_dyn_fn, constraints=[dynamic_dim(y, 0)])(y) def test_export_module_specify_constraints_signature(self): y = torch.randn([3, 3, 3]) @@ -2191,12 +2193,12 @@ def forward(self, x): return x.cos() mod = Mod() - torch._dynamo.export(mod, y) + torch._dynamo.export(mod)(y) with self.assertRaisesRegex( ConstraintViolationError, "def specify_constraints\\(x\\):" ): - torch._dynamo.export(mod, y, constraints=[dynamic_dim(y, 0)]) + torch._dynamo.export(mod, constraints=[dynamic_dim(y, 0)])(y) def test_export_raise_guard_partial_constraint(self): y = torch.randn([3, 3, 3]) @@ -2206,10 +2208,10 @@ def forward(self, x): return x.sin() return x.cos() - torch._dynamo.export(my_dyn_fn, y) + torch._dynamo.export(my_dyn_fn)(y) with self.assertRaises(ConstraintViolationError): - torch._dynamo.export(my_dyn_fn, y, constraints=[dynamic_dim(y, 0)]) + torch._dynamo.export(my_dyn_fn, constraints=[dynamic_dim(y, 0)])(y) def test_export_raise_on_relationship(self): y = torch.randn([3, 3, 3]) @@ -2220,15 +2222,15 @@ def forward(self, x): return a.cos() - torch._dynamo.export(my_dyn_fn, y, y, y) + torch._dynamo.export(my_dyn_fn)(y, y, y) constraints = [dynamic_dim(y, 0)] with self.assertRaises(ConstraintViolationError): - 
torch._dynamo.export(my_dyn_fn, y, y, y, constraints=constraints) + torch._dynamo.export(my_dyn_fn, constraints=constraints)(y, y, y) constraints += [ dynamic_dim(y, 1) == dynamic_dim(y, 0), dynamic_dim(y, 2) == dynamic_dim(y, 0), ] - torch._dynamo.export(my_dyn_fn, y, y, y, constraints=constraints) + torch._dynamo.export(my_dyn_fn, constraints=constraints)(y, y, y) def test_export_no_raise(self): y = torch.randn([3, 3, 3]) @@ -2238,8 +2240,8 @@ def forward(self, x): return a.cos() return a * b * c - torch._dynamo.export(my_dyn_fn, y, y, y) - torch._dynamo.export(my_dyn_fn, y, y, y, constraints=[dynamic_dim(y, 0)]) + torch._dynamo.export(my_dyn_fn)(y, y, y) + torch._dynamo.export(my_dyn_fn, constraints=[dynamic_dim(y, 0)])(y, y, y) def test_export_multi_dynamic_dim_unsafe_relationship(self): x = torch.randn([3, 3, 3]) @@ -2251,12 +2253,12 @@ def forward(self, x): return a.cos() return a * c, b - torch._dynamo.export(my_dyn_fn, x, y, z) + torch._dynamo.export(my_dyn_fn)(x, y, z) constraints = [dynamic_dim(x, 0), dynamic_dim(y, 0), dynamic_dim(z, 0)] with self.assertRaises(ConstraintViolationError): - torch._dynamo.export(my_dyn_fn, x, y, z, constraints=constraints) + torch._dynamo.export(my_dyn_fn, constraints=constraints)(x, y, z) constraints.append(dynamic_dim(z, 0) == dynamic_dim(x, 0)) - torch._dynamo.export(my_dyn_fn, x, y, z, constraints=constraints) + torch._dynamo.export(my_dyn_fn, constraints=constraints)(x, y, z) @config.patch( capture_dynamic_output_shape_ops=True, @@ -2275,11 +2277,10 @@ def forward(self, x): constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10] gm, _ = torch._dynamo.export( f, - *example_inputs, constraints=constraints, aten_graph=True, tracing_mode="symbolic", - ) + )(*example_inputs) self.assertEqual( gm.meta["input_shape_constraints"], @@ -2301,11 +2302,10 @@ def forward(self, x): constraints = [] gm, _ = torch._dynamo.export( f, - y, constraints=constraints, aten_graph=True, tracing_mode="symbolic", - ) + )(y) @config.patch( capture_dynamic_output_shape_ops=True, @@ -2326,11 +2326,10 @@ def forward(self, x): constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10] gm, _ = torch._dynamo.export( f, - *example_inputs, constraints=constraints, aten_graph=True, tracing_mode="symbolic", - ) + )(*example_inputs) # Ensure the exported graph module with metadata is serializable, # metadata won't be saved in the serialized module @@ -2383,16 +2382,16 @@ def forward(self, x): return a.cos() return a * a - torch._dynamo.export(my_dyn_fn, x) + torch._dynamo.export(my_dyn_fn)(x) with self.assertRaises(ConstraintViolationError): - torch._dynamo.export(my_dyn_fn, x, constraints=[dynamic_dim(x, 0)]) + torch._dynamo.export(my_dyn_fn, constraints=[dynamic_dim(x, 0)])(x) def test_symbool(self): def f(x): a = torch.scalar_tensor(x.shape[0] > 4) return x.sin().sum() + a.sum() - gm, _ = torch._dynamo.export(f, torch.ones(6, 4), aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) self.assertEqual(gm(torch.ones(3, 4)), f(torch.ones(3, 4))) def test_export_multi_dynamic_dim_constraint(self): @@ -2405,12 +2404,12 @@ def forward(self, x): return a.cos() return a * c, b - torch._dynamo.export(my_dyn_fn, x, y, z) + torch._dynamo.export(my_dyn_fn)(x, y, z) constraints = [dynamic_dim(x, 0), dynamic_dim(x, 1), dynamic_dim(x, 2)] with self.assertRaises(ConstraintViolationError): - torch._dynamo.export(my_dyn_fn, x, y, z, constraints=constraints) + torch._dynamo.export(my_dyn_fn, constraints=constraints)(x, y, z) 
constraints.append(dynamic_dim(z, 0) == dynamic_dim(x, 0)) - torch._dynamo.export(my_dyn_fn, x, y, z, constraints=constraints) + torch._dynamo.export(my_dyn_fn, constraints=constraints)(x, y, z) def test_export_dynamic_dim_raise_on_compound_range_constraint(self): x = torch.ones(6, 4, 4) @@ -2431,10 +2430,9 @@ def forward(self, x): torch._dynamo.export( foo, - x, constraints=constraints, aten_graph=True, - ) + )(x) def bar(x): if x.shape[0] > 5: # error @@ -2444,10 +2442,9 @@ def forward(self, x): with self.assertRaises(ConstraintViolationError): torch._dynamo.export( bar, - x, constraints=constraints, aten_graph=True, - ) + )(x) def test_list_contains(self): def func(x): @@ -2460,7 +2457,7 @@ def forward(self, x): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) @@ -2479,7 +2476,7 @@ def forward(self, x): torch._dynamo.reset() - exported = torch._dynamo.export(func, *inps, aten_graph=True) + exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) @@ -2493,7 +2490,7 @@ def forward(self, x): return x torch._dynamo.reset() - exported, _ = torch._dynamo.export(func, inp) + exported, _ = torch._dynamo.export(func)(inp) dynamo_result = exported(inp) self.assertTrue(torch._dynamo.utils.same(inp, dynamo_result)) @@ -2516,7 +2513,7 @@ def forward(self, x): inp = torch.randn(3, 128) # In export, int & float in forward should always be specialized - gm, _ = torch._dynamo.export(mod, inp, aten_graph=True) + gm, _ = torch._dynamo.export(mod, aten_graph=True)(inp) count = 0 for node in gm.graph.nodes: if node.op == "placeholder": @@ -2536,7 +2533,7 @@ def forward(self, x): static_sizes = 3, 4 for input_tensor, static_size in zip(input_tensors, static_sizes): m = BasicModule(static_size) - gm, _ = torch._dynamo.export(m, input_tensor, aten_graph=True) + gm, _ = torch._dynamo.export(m, aten_graph=True)(input_tensor) res = gm(input_tensor) self.assertEqual(res.size(0), static_size) self.assertTrue( @@ -2555,7 +2552,7 @@ def forward(self, x): return self.my_lin(x) mod, input_tensor = BasicModule(), torch.randn(2, 3) - gm, guard = torch._dynamo.export(mod, input_tensor, aten_graph=True) + gm, guard = torch._dynamo.export(mod, aten_graph=True)(input_tensor) ref = mod(x=input_tensor) res = gm(x=input_tensor) self.assertTrue(torch._dynamo.utils.same(ref, res)) @@ -2574,8 +2571,8 @@ def forward(self, x): torch.randn(2, 3), torch.randn(2, 3), ) - gm, guard = torch._dynamo.export( - mod, input_tensor, input_tensor2, aten_graph=True + gm, guard = torch._dynamo.export(mod, aten_graph=True)( + input_tensor, input_tensor2 ) ref = mod(input_tensor, input_tensor2) res = gm(input_tensor, input_tensor2) @@ -2594,7 +2591,7 @@ def forward(self, x): RuntimeError, "Constraints violated", ): - torch._dynamo.export(my_dyn_fn, y, constraints=[dynamic_dim(y, 0)]) + torch._dynamo.export(my_dyn_fn, constraints=[dynamic_dim(y, 0)])(y) def test_export_dynamic_dim_cleanup(self): y = torch.randn([3, 3, 3]) @@ -2603,7 +2600,7 @@ def forward(self, x): return x.cos() constraints = [dynamic_dim(y, 0)] - torch._dynamo.export(my_dyn_fn, y, constraints=constraints) + torch._dynamo.export(my_dyn_fn, constraints=constraints)(y) @config.patch(capture_dynamic_output_shape_ops=True) def test_export_dynamic_control_flow_error(self): @@ -2616,7 +2613,7 @@ def forward(self, x): torch._dynamo.exc.UserError, "Dynamic control flow is 
not supported at the moment", ): - gm, _ = torch._dynamo.export(f, torch.randn(5, 6), aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(5, 6)) @config.patch(assume_static_by_default=False) def test_export_persist_assert(self): @@ -2624,8 +2621,8 @@ def forward(self, x): assert x.shape[0] > 4, "Shape must be more than 4" return x.cos() + x.sin() - gm, guard = torch._dynamo.export( - f, torch.randn(5, 4, 6), aten_graph=True, tracing_mode="symbolic" + gm, guard = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")( + torch.randn(5, 4, 6) ) def has_aten_op(gm, op): @@ -2654,13 +2651,13 @@ def forward(self, x): return x.sum() + type(a).func().sum() with self.assertRaisesRegex(torch._dynamo.exc.UserError, "Can't call type()"): - gm, _ = torch._dynamo.export(f, torch.ones(6, 4), aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) def f_correct(x): a = A() return x.sum() + a.__class__.func().sum() - gm, _ = torch._dynamo.export(f_correct, torch.ones(6, 4), aten_graph=True) + gm, _ = torch._dynamo.export(f_correct, aten_graph=True)(torch.ones(6, 4)) self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4))) @@ -2677,10 +2674,9 @@ def forward(self, x): example_inputs = (torch.ones(1, 2, 3),) gm, _ = torch._dynamo.export( Foo(), - *example_inputs, aten_graph=True, tracing_mode="symbolic", - ) + )(*example_inputs) count = 0 for node in gm.graph.nodes: if node.target == torch.ops.aten.add_.Tensor: @@ -2698,9 +2694,9 @@ def forward(self, x): return x[: math.floor(x.shape[0] / 2)] with self.assertRaisesRegex(torch._dynamo.exc.UserError, "Calling round()"): - gm, _ = torch._dynamo.export(f, torch.ones(6, 4), aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) - gm, _ = torch._dynamo.export(f_correct, torch.ones(6, 4), aten_graph=True) + gm, _ = torch._dynamo.export(f_correct, aten_graph=True)(torch.ones(6, 4)) self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4))) @@ -2731,7 +2727,7 @@ def forward(self, x): f_pred_traced_as_tensor_var, f_pred_complex_expression_traced_as_symnode_var, ]: - gm, _ = torch._dynamo.export(f, *example_inputs, aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(*example_inputs) self.assertEqual(gm(*example_inputs), f(*example_inputs)) def test_mixed_real_and_fake_inputs(self): @@ -2759,9 +2755,8 @@ def forward(self, x): example_inputs = (torch.randn(1, 1, 3, 3),) torch._dynamo.export( _TestPattern(), - *example_inputs, aten_graph=True, - ) + )(*example_inputs) @config.patch( capture_dynamic_output_shape_ops=True, @@ -2772,7 +2767,7 @@ def forward(self, x): def f(x, y): return x.size(0) in y - gm, _ = torch._dynamo.export(f, torch.ones(2), torch.ones(3), aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(2), torch.ones(3)) true_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(5)) false_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(2)) @@ -2794,7 +2789,7 @@ def forward(self, x): torch._dynamo.exc.UserError, "Expected 4 arguments", ): - torch._dynamo.export(f, *example_inputs, aten_graph=True) + torch._dynamo.export(f, aten_graph=True)(*example_inputs) def test_cond_raise_user_error_on_unsupported_pred(self): def f_unsupported_pred(x): @@ -2808,9 +2803,8 @@ def forward(self, x): ): torch._dynamo.export( f_unsupported_pred, - *example_inputs, aten_graph=True, - ) + )(*example_inputs) def test_cond_raise_user_error_on_non_list_operands(self): def f_non_list_operands(x): @@ -2823,9 +2817,8 @@ def 
forward(self, x): ): torch._dynamo.export( f_non_list_operands, - *example_inputs, aten_graph=True, - ) + )(*example_inputs) def test_cond_raise_user_error_on_non_tensor_operands(self): def f_non_tensor_operands(x): @@ -2841,9 +2834,8 @@ def forward(self, x): ): torch._dynamo.export( f_non_tensor_operands, - *example_inputs, aten_graph=True, - ) + )(*example_inputs) def test_cond_raise_user_error_on_branch_args_mismatch(self): def true_fn(x, y): @@ -2862,8 +2854,9 @@ def forward(self, x): ): torch._dynamo.export( f_branch_args_mismatch, - *example_inputs, aten_graph=True, + )( + *example_inputs, ) def test_cond_raise_user_error_on_branch_return_non_tensor(self): @@ -2877,9 +2870,8 @@ def forward(self, x): ): torch._dynamo.export( f_branch_return_non_tensor, - *example_inputs, aten_graph=True, - ) + )(*example_inputs) def test_cond_raise_user_error_on_branch_return_multiple_tenors(self): def f_branch_return_multiple_tensors(x, y): @@ -2892,9 +2884,8 @@ def forward(self, x): ): torch._dynamo.export( f_branch_return_multiple_tensors, - *example_inputs, aten_graph=True, - ) + )(*example_inputs) def test_multiple_outputs_op_with_evaluator(self): class TopKModel(torch.nn.Module): @@ -2903,7 +2894,7 @@ def forward(self, x): return torch.sum(values) x = torch.arange(1.0, 6.0, requires_grad=True) - torch._dynamo.export(TopKModel(), x) + torch._dynamo.export(TopKModel())(x) def test_cond_raise_user_error_on_mismatch_return_length(self): def true_fn(x): @@ -2922,9 +2913,8 @@ def forward(self, x): ): torch._dynamo.export( f_mismatch_return_length, - *example_inputs, aten_graph=True, - ) + )(*example_inputs) def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self): def true_fn(x): @@ -2941,10 +2931,8 @@ def forward(self, x): torch._dynamo.exc.UserError, "Expected each tensor to have same metadata but got", ): - torch._dynamo.export( - f_return_tensor_mismatch, + torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)( *example_inputs, - aten_graph=True, ) def test_byte_tensor_does_not_crash(self): @@ -2967,10 +2955,8 @@ def forward(self, x): results.append(x[: x.size(0) - i, i : x.size(2), i:3]) return tuple(results) - gm, _ = torch._dynamo.export( - DynamicSliceExportMod(), + gm, _ = torch._dynamo.export(DynamicSliceExportMod(), aten_graph=True)( torch.randn(5, 5, 5), - aten_graph=True, ) self.assertExpectedInline( @@ -3015,7 +3001,7 @@ def forward(self, x): x = torch.randn(3) for aten_graph in [True, False]: - gm, _ = torch._dynamo.export(f, x, aten_graph=aten_graph) + gm, _ = torch._dynamo.export(f, aten_graph=aten_graph)(x) self.assertTrue( isinstance(gm, torch.fx.GraphModule), msg="test_capture_symbolic_tracing_simple_within_fake_mode_aten_graph_" @@ -3097,7 +3083,7 @@ def forward(self, x): # Export the model with fake inputs and parameters for aten_graph in [True, False]: - graph_module, _ = torch._dynamo.export(model, x, aten_graph=aten_graph) + graph_module, _ = torch._dynamo.export(model, aten_graph=aten_graph)(x) self.assertTrue( isinstance(graph_module, torch.fx.GraphModule), msg="test_capture_symbolic_tracing_within_fake_mode_aten_graph_" @@ -3136,7 +3122,7 @@ def forward(self, x): return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),) - gm, _ = torch._dynamo.export(M(), torch.ones(6, 4), aten_graph=False) + gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4)) self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4))) self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))) @@ -3178,7 +3164,7 @@ def forward(self, x): return (cond(x.shape[0] > 4, 
true_fn, false_fn, [x]),) - gm, _ = torch._dynamo.export(M(), torch.ones(6, 4), aten_graph=False) + gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4)) self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4))) self.assertEqual(gm(torch.ones(5, 4)), M()(torch.ones(5, 4))) self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))) @@ -3231,7 +3217,7 @@ def forward(self, x): pred_y = torch.tensor(False) real_result = mod(pred_y, y) - out_graph, _ = torch._dynamo.export(mod, pred_x, x) + out_graph, _ = torch._dynamo.export(mod)(pred_x, x) self.assertEqual(real_result, out_graph(pred_y, y)) def test_cond_free_variables_overlapping(self): @@ -3259,7 +3245,7 @@ def forward(self, x): x = torch.ones(6, 4) pred_x = torch.tensor(True) - out_graph, _ = torch._dynamo.export(mod, pred_x, x) + out_graph, _ = torch._dynamo.export(mod)(pred_x, x) self.assertExpectedInline( out_graph.code.strip(), """\ diff --git a/test/dynamo/test_export_mutations.py b/test/dynamo/test_export_mutations.py index fbe8759ea9c..c3b73a81783 100644 --- a/test/dynamo/test_export_mutations.py +++ b/test/dynamo/test_export_mutations.py @@ -10,11 +10,11 @@ from torch.testing._internal.common_utils import IS_FBCODE class MutationExportTests(torch._dynamo.test_case.TestCase): def check_failure_on_export(self, mod, *args): with self.assertRaises(AssertionError): - torch._dynamo.export(mod, *args) + torch._dynamo.export(mod)(*args) def check_same_with_export(self, mod, arg): real_result = mod(arg) - graph, _ = torch._dynamo.export(mod, arg) + graph, _ = torch._dynamo.export(mod)(arg) result = graph(arg) self.assertTrue(torch._dynamo.utils.same(result, real_result)) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 07459496bd2..c3f13fef89f 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3243,7 +3243,7 @@ def fn(): def f(x): return 1 + torch._shape_as_tensor(x)[0] - gm, _ = torch._dynamo.export(f, torch.ones(6)) + gm, _ = torch._dynamo.export(f)(torch.ones(6)) input_one_dim = torch.ones(6) input_two_dims = torch.ones(7, 4) @@ -3583,8 +3583,8 @@ def fn(): def f(pred, pred2, x): return cond(pred, true_fn, false_fn, [pred2, x]) - graph, guard = torch._dynamo.export( - f, torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) + graph, guard = torch._dynamo.export(f)( + torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) ) true_true_sin = graph( torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) @@ -3622,8 +3622,8 @@ def fn(): def f(pred, x): return cond(pred, true_fn, false_fn, [x]) - graph, guard = torch._dynamo.export( - f, torch.tensor(False), torch.tensor([0.25, 0.25]) + graph, guard = torch._dynamo.export(f)( + torch.tensor(False), torch.tensor([0.25, 0.25]) ) true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25])) self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror)) @@ -3891,7 +3891,7 @@ def fn(): return self.mod[0](x) m = Mod() - graph, _ = torch._dynamo.export(m, torch.randn(3, 3)) + graph, _ = torch._dynamo.export(m)(torch.randn(3, 3)) def test_nn_sequential_invocation(self): with freeze_rng_state(): @@ -3913,7 +3913,7 @@ def fn(): m = TestModel() x = torch.rand((2, 2)) real = m(x) - graph, _ = torch._dynamo.export(m, x) + graph, _ = torch._dynamo.export(m)(x) dynamo_result = graph(x) self.assertTrue(same(real, dynamo_result)) @@ -3937,7 +3937,7 @@ def fn(): m = TestModel() x = torch.rand((2, 2)) real = m(x) - graph, _ = torch._dynamo.export(m, x) + graph, _ = torch._dynamo.export(m)(x) 
dynamo_result = graph(x) self.assertTrue(same(real, dynamo_result)) @@ -4724,7 +4724,7 @@ def fn(): b.tag = "b" b.frog = "ribbit" - exported = torch._dynamo.export(foo, a, b) + exported = torch._dynamo.export(foo)(a, b) out_graph = exported[0] nodes = list(out_graph.graph.nodes) @@ -4772,7 +4772,7 @@ def fn(): state[0].tag = "STATE_0" state[1].tag = "HMMM" - exported = torch._dynamo.export(pre_attention_state_ops, i, mems, state) + exported = torch._dynamo.export(pre_attention_state_ops)(i, mems, state) out_graph = exported[0] nodes = list(out_graph.graph.nodes) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 8c6830b0f4a..a40f22fed4f 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1410,7 +1410,7 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): mod = ModuleSpecialFwd() rx = torch.randn([3, 10, 10]) real = mod(rx) - graph, _ = torch._dynamo.export(mod, rx) + graph, _ = torch._dynamo.export(mod)(rx) self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) def test_conv_call_forward_directly(self): diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index fb051f6a72e..7f6254ab477 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1188,7 +1188,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnt.frame_count, 2) self.assertEqual(cnt.op_count, 3) # rand, rand try: - graph, _ = torch._dynamo.export(fn) + graph, _ = torch._dynamo.export(fn)() # See https://github.com/pytorch/pytorch/pull/87490 self.fail("unexpected export success") except torch._dynamo.exc.Unsupported: @@ -2650,11 +2650,11 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnt.op_count, 6) self.assertEqual(cnt.frame_count, 1) - exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5])) self.assertTrue(same(exported(*args), f(*args))) with self.assertRaisesRegex(RuntimeError, "First dim need to be 3"): - exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5])) def test_not_rewrite_assert_for_other_errors(self): def f(x): @@ -2686,11 +2686,11 @@ class ReproTests(torch._dynamo.test_case.TestCase): return x.cos() + b args = (torch.Tensor([3, 4, 5]),) - exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5])) self.assertTrue(same(exported(*args), f(*args))) with self.assertRaisesRegex(RuntimeError, "assertion error"): - exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5])) def test_rewrite_assert_with_non_string_msg(self): def f(x): @@ -2718,7 +2718,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): return x.cos() + b args = (torch.Tensor([3, 4, 5]),) - exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5])) self.assertTrue(same(exported(*args), f(*args))) cnt = torch._dynamo.testing.CompileCounter() @@ -2728,7 +2728,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnt.op_count, 3) self.assertEqual(cnt.frame_count, 1) - exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5])) self.assertTrue(same(exported(*args), f(*args))) def test_size_typematch(self): @@ -2885,7 +2885,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): 
inp = torch.randn(6, 5) - gm, _ = torch._dynamo.export(f, torch.randn(4, 5), aten_graph=True) + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5)) self.assertEqual(gm(inp).shape, f(inp).shape) @torch._dynamo.config.patch("specialize_int", False) @@ -3058,9 +3058,10 @@ class ReproTests(torch._dynamo.test_case.TestCase): gm, _ = torch._dynamo.export( f, + aten_graph=True, + )( torch.zeros(6, 4), torch.tensor(1), - aten_graph=True, ) self.assertEqual( f(torch.zeros(6, 4), torch.tensor(1)), @@ -3158,8 +3159,9 @@ class ReproTests(torch._dynamo.test_case.TestCase): gm, _ = torch._dynamo.export( f, - torch.zeros(6, 4), aten_graph=True, + )( + torch.zeros(6, 4), ) self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4))) diff --git a/test/fx/test_source_matcher_utils.py b/test/fx/test_source_matcher_utils.py index dd6ccb77b25..3864eca829c 100644 --- a/test/fx/test_source_matcher_utils.py +++ b/test/fx/test_source_matcher_utils.py @@ -30,7 +30,7 @@ class TestSourceMatcher(JitTestCase): return x inputs = (torch.randn(3, 3),) - gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True) + gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs) gm.graph.eliminate_dead_code() module_partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU]) @@ -69,7 +69,7 @@ class TestSourceMatcher(JitTestCase): return self.maxpool(self.relu(z)) inputs = (torch.randn(1, 3, 256, 256),) - gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), *inputs, aten_graph=True) + gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)(*inputs) gm.graph.eliminate_dead_code() module_partitions = get_source_partitions(gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d]) @@ -111,7 +111,7 @@ class TestSourceMatcher(JitTestCase): return x inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3)) - gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True) + gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs) gm.graph.eliminate_dead_code() module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.conv2d]) @@ -135,7 +135,7 @@ class TestSourceMatcher(JitTestCase): return x inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5)) - gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True) + gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs) gm.graph.eliminate_dead_code() module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu]) diff --git a/test/onnx/test_fx_passes.py b/test/onnx/test_fx_passes.py index fc507a4c2bb..76823fd6f30 100644 --- a/test/onnx/test_fx_passes.py +++ b/test/onnx/test_fx_passes.py @@ -16,7 +16,7 @@ class TestFxPasses(common_utils.TestCase): x = torch.randn(3) y = torch.randn(3) z = torch.randn(3) - gm, _ = torch._dynamo.export(func, x, y, z) + gm, _ = torch._dynamo.export(func)(x, y, z) torch._dynamo.reset() # Purposely name the nodes in a way that will cause a recursive collision later. @@ -44,7 +44,7 @@ class TestFxPasses(common_utils.TestCase): x = torch.randn(3) y = torch.randn(3) z = torch.randn(3) - gm, _ = torch._dynamo.export(func, x, y, z) + gm, _ = torch._dynamo.export(func)(x, y, z) torch._dynamo.reset() # Run `set_node_name` and verify that the names are correct. 
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 3390a71d80f..1e7bc064a34 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -833,7 +833,7 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer): def export( f: Callable[..., Any], - *args, + *extra_args, aten_graph: bool = False, pre_dispatch: bool = False, decomposition_table: Optional[ @@ -843,16 +843,14 @@ def export( constraints: Optional[List[Constraint]] = None, assume_static_by_default: bool = False, fake_mode: fake_tensor.FakeTensorMode = None, - **kwargs, -) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]: + **extra_kwargs, +) -> Callable[..., Tuple[torch.fx.GraphModule, Set[_guards.Guard]]]: """ Export an input function f to a format that can be executed outside of PyTorch using the FX graph. Args: f (callable): A PyTorch function to be exported. - *args: Variable length argument list to be passed to the function f. - aten_graph (bool): If True, exports a graph with ATen operators. If False, exports a graph with Python operators. Default is False. @@ -872,10 +870,8 @@ def export( Useful during symbolic tracing, when user input is already fakefied. Implies free fake tensors are allowed on `make_fx`. `fake_mode` must contain a valid (not None) `shape_env` instance. - **kwargs: Arbitrary keyword arguments to be passed to the function f. - Returns: - A tuple of (graph, guards) + A function that given args and kwargs, returns a tuple of (graph, guards) Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options. Guards: The guards we accumulated during tracing f above @@ -887,299 +883,329 @@ def export( Note - this headerdoc was authored by ChatGPT, with slight modifications by the author. 
""" - check_if_dynamo_supported() - torch._C._log_api_usage_once("torch._dynamo.export") - if decomposition_table is not None: - assert ( - aten_graph - ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True" - if pre_dispatch: - assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True" - f = innermost_fn(f) - call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f - original_signature = inspect.signature(call_to_inspect) + # Deal with "local variable referenced before assignment" + _fake_mode = fake_mode + _f = f + _assume_static_by_default = assume_static_by_default - assert ( - not fake_mode or fake_mode.shape_env is not None - ), "The specified fake_mode must contain a valid shape_env" - graph = None - out_guards = None - graph_captured_input = None - graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None - fake_mode = fake_mode or _guards.detect_fake_mode(args) - _allow_fake_constant: bool = ( - fake_mode is not None - ) # Allow fake constants during symbolic tracing - - def produce_matching(source_args, candidate_args): - matched_elements_positions = [] - dict_of_source_args = dict() - for i in range(0, len(source_args)): - element_id = id(source_args[i]) - dict_of_source_args[element_id] = i - - for i in range(0, len(candidate_args)): - arg = candidate_args[i] - # 1-element tensor arg can be unspec int/float - if isinstance(arg, torch.Tensor) and torch.numel(arg) == 1: - if id(arg) in dict_of_source_args: - matched_elements_positions.append(dict_of_source_args[id(arg)]) - elif id(arg.item()) in dict_of_source_args: - matched_elements_positions.append( - dict_of_source_args[id(arg.item())] - ) - else: - raise AssertionError( - "Dynamo input/output is not consistent with traced input/output" - ) - else: - assert ( - id(arg) in dict_of_source_args - ), "Dynamo input and output is a strict subset of traced input/output" - matched_elements_positions.append(dict_of_source_args[id(arg)]) - - return matched_elements_positions - - def guard_export_print(guards: Set[_guards.Guard]): - nonlocal out_guards - assert out_guards is None, "whole graph export entails exactly one guard export" - out_guards = guards - - example_inputs = [] - - def dynamo_normalization_capturing_compiler( - gm: torch.fx.GraphModule, inner_example_inputs - ): - nonlocal graph - assert ( - graph is None - ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." 
-        graph = gm
-
-        nonlocal fake_mode, example_inputs
-        fake_mode = fake_mode or _guards.detect_fake_mode(inner_example_inputs)
-        example_inputs = inner_example_inputs
-
-        def result_capturing_wrapper(*graph_inputs):
-            nonlocal graph_captured_result
-            nonlocal graph_captured_input
-
-            graph_captured_input = graph_inputs
-            assert graph is not None
-            graph_captured_result = graph(*graph_inputs)
-            return graph_captured_result
-
-        return result_capturing_wrapper
-
-    flat_args, in_spec = pytree.tree_flatten((args, kwargs))
-
-    remove_from_cache(f)
-    constraint_violation_error = None
-    if tracing_mode != "symbolic":
-        assume_static_by_default = True
-    with patch(f"{__name__}.most_recent_backend", None), config.patch(
-        specialize_int=True,
-        assume_static_by_default=assume_static_by_default,
-        automatic_dynamic_shapes=False,
-        capture_dynamic_output_shape_ops=True,
-        capture_scalar_outputs=True,
-    ), torch._guards.export_fake_mode(fake_mode):
-        opt_f = optimize_assert(
-            dynamo_normalization_capturing_compiler,
-            hooks=Hooks(
-                guard_export_fn=guard_export_print,
-                guard_fail_fn=None,
-            ),
-            export=True,
-            export_constraints=constraints,
-        )(f)
-        # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.
-        try:
-            result_traced = opt_f(*args, **kwargs)
-        except ConstraintViolationError as e:
-            constraint_violation_error = e
-    remove_from_cache(f)
-
-    if (
-        (shape_env := getattr(fake_mode, "shape_env", None)) is not None
-        and (dim_constraints := shape_env.dim_constraints) is not None
-        and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
-    ):
-        dim_constraints.solve()
-        msg = dim_constraints.prettify_results(original_signature)
-        forced_specializations = dim_constraints.forced_specializations()
-        if forced_specializations:
-            msg = (
-                "Some dynamic dimensions need to be specialized because "
-                "the constraints inferred for them are too complex to specify.\n"
-                f"{forced_specializations}\n{msg}"
-            )
-        if constraint_violation_error:
-            constraint_violation_error.args = (
-                constraint_violation_error.args[0] + msg,
-            )
-        else:
-            if forced_specializations:
-                constraint_violation_error = ConstraintViolationError(msg)
-            else:
-                log.info(
-                    "Summary of dimension constraints:%s",
-                    msg,
-                )
-
-    # Error if we have any constraints on static values
-    for k in shape_env.var_to_range.keys():
-        if isinstance(k, sympy.Integer):
-            constraint_violation_error = ConstraintViolationError(
-                f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
-                "It appears that you're trying to set a constraint on a "
-                f"value which we evaluated to have a static value of {k}. "
-                "Scroll up to see where this constraint was set."
-            )
-    if constraint_violation_error:
-        raise constraint_violation_error
-
-    assert (
-        graph is not None
-    ), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph."
-    assert out_guards is not None, "Failed to produce guards during tracing"
-    assert fake_mode is not None
-
-    matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)
-
-    # NB: This is mostly hitting the cache; Dynamo already converted these
-    example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
-    flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
-
-    assert graph_captured_result is not None
-    flat_both = list(graph_captured_result) + flat_args
-    matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
-
-    if aten_graph:
-        # Running graph with interpreter is needed for propagating the stack_trace
-        def graph_with_interpreter(*args):
-            with torch.fx.traceback.preserve_node_meta():
-                return torch.fx.Interpreter(graph).run(*args)
-
-        with enable_python_dispatcher(), fake_mode:
-            try:
-                graph = make_fx(
-                    graph_with_interpreter,
-                    decomposition_table=decomposition_table,
-                    tracing_mode="real",
-                    _allow_non_fake_inputs=True,
-                    pre_dispatch=pre_dispatch,
-                    _allow_fake_constant=_allow_fake_constant,
-                )(*example_fake_inputs)
-            except CondOpArgsMismatchError as e:
-                # Wrap the internal error to the user-facing error
-                raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
-
-    new_graph = FlattenInputOutputSignature(
-        graph,
-        flat_args,
-        matched_input_elements_positions,
-        matched_output_elements_positions,
-        example_fake_inputs,
-        fake_mode,
-    ).transform()
-
-    # Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check
-    new_graph.meta["input_shape_constraints"] = (
-        [constraint.serializable_spec for constraint in constraints]
-        if constraints
-        else []
-    )
-
-    def signature_to_fullargspec(sig: inspect.Signature):
-        # Get a list of Parameter objects from the Signature object
-        params = list(sig.parameters.values())
-        # Separate positional arguments, keyword-only arguments and varargs/varkw
-        args = [
-            p.name for p in params if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
-        ]
-        kwonlyargs = [
-            p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
-        ]
-        varargs = next(
-            (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL), None
-        )
-        varkw = next(
-            (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD), None
-        )
-        # Get default values for positional arguments and keyword-only arguments
-        defaults = tuple(
-            p.default
-            for p in params
-            if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
-            and p.default is not inspect.Parameter.empty
-        )
-        kwonlydefaults = {
-            p.name: p.default
-            for p in params
-            if p.kind == inspect.Parameter.KEYWORD_ONLY
-            and p.default is not inspect.Parameter.empty
-        }
-        # Get annotations for parameters and return value
-        annotations = {}
-        if sig.return_annotation:
-            annotations = {"return": sig.return_annotation}
-        for parameter in params:
-            annotations[parameter.name] = parameter.annotation
-        # Return a FullArgSpec object with the extracted attributes
-        return inspect.FullArgSpec(
-            args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
-        )
-
-    # Make dynamo graph to have same input/output spec as user code
-    def argument_names(f: Callable[..., Any], *args, **kwargs) -> List[str]:
-        fullargspec = signature_to_fullargspec(original_signature)
-
-        # 1. Map `args` 1-to-1 to positional arguments in original signature.
-        input_strs = fullargspec.args[: len(args)]
-
-        if len(args) > len(fullargspec.args):
-            # 2. If there are more arguments left in `args`, they map to varargs in original
-            # signature. Assign names as {varargs}_0, {varargs}_1, ...
-            assert fullargspec.varargs is not None, "More arguments than expected"
-            input_strs += [
-                f"{fullargspec.varargs}_{i}"
-                for i in range(0, len(args) - len(input_strs))
-            ]
-        elif len(args) < len(fullargspec.args):
-            # 3. If there are fewer arguments in `args` than `fullargspec.args`,
-            # it implies these are arguments either with default values, or provided in
-            # `kwargs`. The former can be safely ignored. Because Dynamo.export does not
-            # export them as part of the function signature. The latter will be handled
-            # in the next step.
-            for unprovided_arg in fullargspec.args[
-                len(args) : -len(fullargspec.defaults or [])
-            ]:
-                assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
-
-        # 4. Keyword arguments provided in `kwargs`.
-        input_strs += list(kwargs.keys())
-
-        # 5. Keyword-only arguments with default values if not provided are not exported
-        # as part of the function signature.
-        for kwonly_arg in fullargspec.kwonlyargs:
-            kwonlydefaults = fullargspec.kwonlydefaults or {}
+    def inner(*args, **kwargs):
+        fake_mode = _fake_mode
+        f = _f
+        assume_static_by_default = _assume_static_by_default
+        check_if_dynamo_supported()
+        torch._C._log_api_usage_once("torch._dynamo.export")
+        if decomposition_table is not None:
             assert (
-                kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
-            ), f"Missing keyword only argument {kwonly_arg}"
-
-        return input_strs
-
-    new_graph.graph._codegen = _PyTreeCodeGen(
-        _PyTreeInfo(
-            argument_names(f, *args, **kwargs),
-            in_spec,
-            out_spec_traced,
-        )
-    )
-    new_graph.recompile()
-    return (new_graph, out_guards)
+                aten_graph
+            ), "Specifying a decomposition_table or tracing mode is illegal without setting aten_graph=True"
+        if pre_dispatch:
+            assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
+        f = innermost_fn(f)
+        call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
+        original_signature = inspect.signature(call_to_inspect)
+        assert (
+            not fake_mode or fake_mode.shape_env is not None
+        ), "The specified fake_mode must contain a valid shape_env"
+        graph = None
+        out_guards = None
+        graph_captured_input = None
+        graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
+        fake_mode = fake_mode or _guards.detect_fake_mode(args)
+        _allow_fake_constant: bool = (
+            fake_mode is not None
+        )  # Allow fake constants during symbolic tracing
+
+        def produce_matching(source_args, candidate_args):
+            matched_elements_positions = []
+            dict_of_source_args = dict()
+            for i in range(0, len(source_args)):
+                element_id = id(source_args[i])
+                dict_of_source_args[element_id] = i
+
+            for i in range(0, len(candidate_args)):
+                arg = candidate_args[i]
+                # 1-element tensor arg can be unspec int/float
+                if isinstance(arg, torch.Tensor) and torch.numel(arg) == 1:
+                    if id(arg) in dict_of_source_args:
+                        matched_elements_positions.append(dict_of_source_args[id(arg)])
+                    elif id(arg.item()) in dict_of_source_args:
+                        matched_elements_positions.append(
+                            dict_of_source_args[id(arg.item())]
+                        )
+                    else:
+                        raise AssertionError(
+                            "Dynamo input/output is not consistent with traced input/output"
+                        )
+                else:
+                    assert (
+                        id(arg) in dict_of_source_args
+                    ), "Dynamo input and output are a strict subset of traced input/output"
+                    matched_elements_positions.append(dict_of_source_args[id(arg)])
+
+            return matched_elements_positions
+
+        def guard_export_print(guards: Set[_guards.Guard]):
+            nonlocal out_guards
+            assert (
+                out_guards is None
+            ), "whole graph export entails exactly one guard export"
+            out_guards = guards
+
+        example_inputs = []
+
+        def dynamo_normalization_capturing_compiler(
+            gm: torch.fx.GraphModule, inner_example_inputs
+        ):
+            nonlocal graph
+            assert (
+                graph is None
+            ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
+            graph = gm
+
+            nonlocal fake_mode, example_inputs
+            fake_mode = fake_mode or _guards.detect_fake_mode(inner_example_inputs)
+            example_inputs = inner_example_inputs
+
+            def result_capturing_wrapper(*graph_inputs):
+                nonlocal graph_captured_result
+                nonlocal graph_captured_input
+
+                graph_captured_input = graph_inputs
+                assert graph is not None
+                graph_captured_result = graph(*graph_inputs)
+                return graph_captured_result
+
+            return result_capturing_wrapper
+
+        flat_args, in_spec = pytree.tree_flatten((args, kwargs))
+
+        remove_from_cache(f)
+        constraint_violation_error = None
+        if tracing_mode != "symbolic":
+            assume_static_by_default = True
+        with patch(f"{__name__}.most_recent_backend", None), config.patch(
+            specialize_int=True,
+            assume_static_by_default=assume_static_by_default,
+            automatic_dynamic_shapes=False,
+            capture_dynamic_output_shape_ops=True,
+            capture_scalar_outputs=True,
+        ), torch._guards.export_fake_mode(fake_mode):
+            opt_f = optimize_assert(
+                dynamo_normalization_capturing_compiler,
+                hooks=Hooks(
+                    guard_export_fn=guard_export_print,
+                    guard_fail_fn=None,
+                ),
+                export=True,
+                export_constraints=constraints,
+            )(f)
+            # TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
+            try:
+                result_traced = opt_f(*args, **kwargs)
+            except ConstraintViolationError as e:
+                constraint_violation_error = e
+        remove_from_cache(f)
+
+        if (
+            (shape_env := getattr(fake_mode, "shape_env", None)) is not None
+            and (dim_constraints := shape_env.dim_constraints) is not None
+            and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
+        ):
+            dim_constraints.solve()
+            msg = dim_constraints.prettify_results(original_signature)
+            forced_specializations = dim_constraints.forced_specializations()
+            if forced_specializations:
+                msg = (
+                    "Some dynamic dimensions need to be specialized because "
+                    "the constraints inferred for them are too complex to specify.\n"
+                    f"{forced_specializations}\n{msg}"
+                )
+            if constraint_violation_error:
+                constraint_violation_error.args = (
+                    constraint_violation_error.args[0] + msg,
+                )
+            else:
+                if forced_specializations:
+                    constraint_violation_error = ConstraintViolationError(msg)
+                else:
+                    log.info(
+                        "Summary of dimension constraints:%s",
+                        msg,
+                    )
+
+            # Error if we have any constraints on static values
+            for k in shape_env.var_to_range.keys():
+                if isinstance(k, sympy.Integer):
+                    constraint_violation_error = ConstraintViolationError(
+                        f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
+                        "It appears that you're trying to set a constraint on a "
+                        f"value which we evaluated to have a static value of {k}. "
+                        "Scroll up to see where this constraint was set."
+                    )
+        if constraint_violation_error:
+            raise constraint_violation_error
+
+        assert (
+            graph is not None
+        ), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph."
+        assert out_guards is not None, "Failed to produce guards during tracing"
+        assert fake_mode is not None
+
+        matched_input_elements_positions = produce_matching(
+            flat_args, graph_captured_input
+        )
+
+        # NB: This is mostly hitting the cache; Dynamo already converted these
+        example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
+        flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
+
+        assert graph_captured_result is not None
+        flat_both = list(graph_captured_result) + flat_args
+        matched_output_elements_positions = produce_matching(
+            flat_both, flat_results_traced
+        )
+
+        if aten_graph:
+            # Running graph with interpreter is needed for propagating the stack_trace
+            def graph_with_interpreter(*args):
+                with torch.fx.traceback.preserve_node_meta():
+                    return torch.fx.Interpreter(graph).run(*args)
+
+            with enable_python_dispatcher(), fake_mode:
+                try:
+                    graph = make_fx(
+                        graph_with_interpreter,
+                        decomposition_table=decomposition_table,
+                        tracing_mode="real",
+                        _allow_non_fake_inputs=True,
+                        pre_dispatch=pre_dispatch,
+                        _allow_fake_constant=_allow_fake_constant,
+                    )(*example_fake_inputs)
+                except CondOpArgsMismatchError as e:
+                    # Wrap the internal error to the user-facing error
+                    raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
+
+        new_graph = FlattenInputOutputSignature(
+            graph,
+            flat_args,
+            matched_input_elements_positions,
+            matched_output_elements_positions,
+            example_fake_inputs,
+            fake_mode,
+        ).transform()
+
+        # Store constraints and inputs as metadata for user passes, e.g., turning constraints into runtime checks
+        new_graph.meta["input_shape_constraints"] = (
+            [constraint.serializable_spec for constraint in constraints]
+            if constraints
+            else []
+        )
+
+        def signature_to_fullargspec(sig: inspect.Signature):
+            # Get a list of Parameter objects from the Signature object
+            params = list(sig.parameters.values())
+            # Separate positional arguments, keyword-only arguments and varargs/varkw
+            args = [
+                p.name
+                for p in params
+                if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
+            ]
+            kwonlyargs = [
+                p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
+            ]
+            varargs = next(
+                (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
+                None,
+            )
+            varkw = next(
+                (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
+                None,
+            )
+            # Get default values for positional arguments and keyword-only arguments
+            defaults = tuple(
+                p.default
+                for p in params
+                if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
+                and p.default is not inspect.Parameter.empty
+            )
+            kwonlydefaults = {
+                p.name: p.default
+                for p in params
+                if p.kind == inspect.Parameter.KEYWORD_ONLY
+                and p.default is not inspect.Parameter.empty
+            }
+            # Get annotations for parameters and return value
+            annotations = {}
+            if sig.return_annotation:
+                annotations = {"return": sig.return_annotation}
+            for parameter in params:
+                annotations[parameter.name] = parameter.annotation
+            # Return a FullArgSpec object with the extracted attributes
+            return inspect.FullArgSpec(
+                args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
+            )
+
+        # Make the dynamo graph have the same input/output spec as user code
+        def argument_names(f: Callable[..., Any], *args, **kwargs) -> List[str]:
+            fullargspec = signature_to_fullargspec(original_signature)
+
+            # 1. Map `args` 1-to-1 to positional arguments in original signature.
+            input_strs = fullargspec.args[: len(args)]
+
+            if len(args) > len(fullargspec.args):
+                # 2. If there are more arguments left in `args`, they map to varargs in original
+                # signature. Assign names as {varargs}_0, {varargs}_1, ...
+                assert fullargspec.varargs is not None, "More arguments than expected"
+                input_strs += [
+                    f"{fullargspec.varargs}_{i}"
+                    for i in range(0, len(args) - len(input_strs))
+                ]
+            elif len(args) < len(fullargspec.args):
+                # 3. If there are fewer arguments in `args` than `fullargspec.args`,
+                # it implies these are arguments either with default values, or provided in
+                # `kwargs`. The former can be safely ignored, because Dynamo.export does
+                # not export them as part of the function signature. The latter will be
+                # handled in the next step.
+                for unprovided_arg in fullargspec.args[
+                    len(args) : -len(fullargspec.defaults or [])
+                ]:
+                    assert (
+                        unprovided_arg in kwargs
+                    ), f"Missing argument {unprovided_arg}"
+
+            # 4. Keyword arguments provided in `kwargs`.
+            input_strs += list(kwargs.keys())
+
+            # 5. Keyword-only arguments with default values if not provided are not exported
+            # as part of the function signature.
+            for kwonly_arg in fullargspec.kwonlyargs:
+                kwonlydefaults = fullargspec.kwonlydefaults or {}
+                assert (
+                    kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
+                ), f"Missing keyword only argument {kwonly_arg}"
+
+            return input_strs
+
+        new_graph.graph._codegen = _PyTreeCodeGen(
+            _PyTreeInfo(
+                argument_names(f, *args, **kwargs),
+                in_spec,
+                out_spec_traced,
+            )
+        )
+
+        new_graph.recompile()
+        return (new_graph, out_guards)
+
+    if extra_args or extra_kwargs:
+        warnings.warn(
+            "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. "
+            "If you don't migrate, we may break your export call in the future if your user-defined kwargs "
+            "conflict with future kwargs added to export(f)."
+        )
+        return inner(*extra_args, **extra_kwargs)
+    else:
+        return inner
 
 
 def optimize_assert(
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 9c1e4fea887..05ac4c3a5b6 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -62,7 +62,6 @@ from .passes.add_runtime_assertions_for_constraints_pass import _AddRuntimeAsser
 #
 # result = torch._dynamo.export(
 #     my_model,
-#     *sixtyfour_tensors,
 #     constraints=[
 #         # if you do only dynamic_dim, this is sugar for
 #         # -Inf <= dynamic_dim(blah, 0) <= Inf; we don’t otherwise
@@ -74,6 +73,8 @@ from .passes.add_runtime_assertions_for_constraints_pass import _AddRuntimeAsser
 #         # NB: But we actually truncate ranges to be >= 2, because of
 #         # 0/1 specialization
 #     ]
+# )(
+#     *sixtyfour_tensors,
 # )
 def dynamic_dim(t: torch.Tensor, index: int):
     if not isinstance(t, torch.Tensor):
@@ -152,10 +153,11 @@
     try:
         gm_torch_level, _ = torch._dynamo.export(
             f,
-            *args,
             constraints=constraints,
             assume_static_by_default=True,
             tracing_mode="symbolic",
+        )(
+            *args,
             **kwargs,
         )
diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py
index 494604506c7..2dc244da273 100644
--- a/torch/ao/quantization/pt2e/utils.py
+++ b/torch/ao/quantization/pt2e/utils.py
@@ -141,9 +141,10 @@ def get_aten_graph_module(
     import torch._dynamo
     aten_pattern, _ = torch._dynamo.export(
         pattern,
-        *copy.deepcopy(example_inputs),
         aten_graph=True,
         tracing_mode="real",
+    )(
+        *copy.deepcopy(example_inputs),
         **kwargs,
     )
     aten_pattern.graph.eliminate_dead_code()
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
index bd6df9aa5b1..f4574685e97 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -64,7 +64,7 @@ def _mark_nodes_as_annotated(nodes: List[Node]):
 
 def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
-    gm, _ = torchdynamo.export(function, *inputs, aten_graph=True)
+    gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
     gm.graph.eliminate_dead_code()
     return gm.graph
diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py
index 1b4a2d250ea..da898c0aada 100644
--- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py
+++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py
@@ -191,9 +191,10 @@ class DynamoExport(exporter.FXGraphExtractor):
         fx_mode = "symbolic" if options.dynamic_shapes else "fake"
         graph_module, graph_guard = torch._dynamo.export(
             wrapped_model,
-            *model_args,
             tracing_mode=fx_mode,
             fake_mode=fake_mode,  # type: ignore[arg-type]
+        )(
+            *model_args,
             **model_kwargs,
         )
         del graph_guard  # Unused
diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py
index 9795883a448..27722650dae 100644
--- a/torch/onnx/_internal/fx/passes/modularization.py
+++ b/torch/onnx/_internal/fx/passes/modularization.py
@@ -798,7 +798,7 @@ class Modularize(_pass.Transform):
         >>>     out = self.linear(out)
         >>>     return out
         >>>
-        >>> gm, _ = torch._dynamo.export(TestModule(), torch.tensor([0, 1, 2]), aten_graph=True)
+        >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(torch.tensor([0, 1, 2]))
         >>> gm.print_readable()
         >>> gm = passes.Modularize(infra.DiagnosticContext("test_context", "1.0"), gm).run()
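The control flow added to eval_frame.py is a standard two-stage (curried) API with a backward-compatibility shim. A minimal standalone sketch of that shape, using hypothetical names (export_like, and a stand-in body for the real tracing work), not the actual PyTorch implementation:

    import warnings

    def export_like(f, *extra_args, aten_graph=False, **extra_kwargs):
        def inner(*args, **kwargs):
            # Stand-in for the real work: trace f(*args, **kwargs) under the
            # options captured by the enclosing export_like(...) call.
            return f(*args, **kwargs)

        if extra_args or extra_kwargs:
            # Legacy call style export_like(f, x) still works, but warns and
            # runs immediately, mirroring the deprecation branch in the patch.
            warnings.warn(
                "export_like(f, *args) is deprecated; use export_like(f)(*args)"
            )
            return inner(*extra_args, **extra_kwargs)
        return inner

Routing stray positionals and keywords into *extra_args/**extra_kwargs is what lets old call sites keep working during the deprecation window, while the new two-call form keeps export's own options in a separate namespace from the user's example inputs, so future keyword options added to export(f) can no longer collide with user-supplied kwargs.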