diff --git a/functorch/functorch/_src/aot_autograd.py b/functorch/functorch/_src/aot_autograd.py
index 743ac584843..97a86913d8d 100644
--- a/functorch/functorch/_src/aot_autograd.py
+++ b/functorch/functorch/_src/aot_autograd.py
@@ -184,7 +184,8 @@ def _reshape_alias(x, shape, strides):
     return aten.view(x, shape)
 
 
-graph_being_compiled: str = None
+# This is a list since looking forward, we can have this arbitrarily nested.
+graph_being_compiled: List[str] = []
 nth_graph: int = 0
 model_name: str = "model"
 
@@ -194,23 +195,30 @@ def set_model_name(name):
     model_name = name
 
 
-def get_graph_being_compiled() -> str:
+def get_aot_compilation_context() -> Tuple[List[str], str, int]:
+    return list(graph_being_compiled), model_name, nth_graph
+
+
+def get_aot_graph_name() -> str:
     """
     Returns the name of the graph being compiled.
     """
     global model_name, graph_being_compiled, nth_graph
-    return f"{model_name}_{graph_being_compiled}_{nth_graph}"
+    return f"{model_name}_{'_'.join(graph_being_compiled)}_{nth_graph}"
+
+
+get_graph_being_compiled = get_aot_graph_name
 
 
 @contextmanager
 def track_graph_compiling(graph_name, increment_index=False):
     global graph_being_compiled
-    graph_being_compiled = graph_name
+    graph_being_compiled = [graph_name]
     yield
     if increment_index:
         global nth_graph
         nth_graph += 1
-    graph_being_compiled = None
+    graph_being_compiled = []
 
 
 def make_boxed_func(f):
@@ -264,7 +272,7 @@ class AOTConfig:
 
 def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
     fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
-    with track_graph_compiling("forward"):
+    with track_graph_compiling("inference"):
         compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
 
     @wraps(compiled_fw)
diff --git a/functorch/functorch/compile/__init__.py b/functorch/functorch/compile/__init__.py
index 1568d5687b9..99e0456a4e4 100644
--- a/functorch/functorch/compile/__init__.py
+++ b/functorch/functorch/compile/__init__.py
@@ -10,6 +10,8 @@ from .._src.aot_autograd import (
     clear_compile_cache,
     aot_module_simplified,
     get_graph_being_compiled,
+    get_aot_graph_name,
+    get_aot_compilation_context,
     make_boxed_func,
     make_boxed_compiler
 )
diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py
index 5deeac1eb27..e1d4b3c4ccd 100644
--- a/functorch/test/test_pythonkey.py
+++ b/functorch/test/test_pythonkey.py
@@ -23,8 +23,10 @@ from functorch import (
 from functorch._src.aot_autograd import aot_module_simplified
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
-    min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
-    num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion, clear_compile_cache
+    min_cut_rematerialization_partition, aot_function, aot_module,
+    decomposition_table, nop,
+    num_of_recompilations, default_partition, default_decompositions,
+    memory_efficient_fusion, clear_compile_cache, get_aot_compilation_context
 )
 
 from torch.testing._internal.common_device_type import ops
@@ -330,6 +332,22 @@ class TestAOTAutograd(AOTTestCase):
         inp = [torch.randn(5, requires_grad=True) for _ in range(3)]
         f(*inp).sum().backward()
 
+    def test_compilation_context(self):
+        def f(x):
+            return x.sin().sin()
+        count = []
+
+        def compiler(fx_g, _):
+            context = get_aot_compilation_context()
+            count.append((context[0], len(fx_g.graph.nodes)))
+            return fx_g
+
+        f = aot_function(f, compiler)
+        out = f(torch.randn(5, requires_grad=True))
+        f(torch.randn(5))
+        out.sum().backward()
+        self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])
+
 
 class TestEagerFusionOpInfo(AOTTestCase):