Added inference to context when only compiling forwards (#83783)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83783
Approved by: https://github.com/pyjhzwh, https://github.com/jansel
This commit is contained in:
Horace He 2022-08-20 01:46:32 +00:00 committed by PyTorch MergeBot
parent 08c03c91d7
commit f45cd00d7a
3 changed files with 36 additions and 8 deletions

View file

@@ -184,7 +184,8 @@ def _reshape_alias(x, shape, strides):
return aten.view(x, shape)
graph_being_compiled: str = None
# This is a list since looking forward, we can have this arbitrarily nested.
graph_being_compiled: List[str] = []
nth_graph: int = 0
model_name: str = "model"
@@ -194,23 +195,30 @@ def set_model_name(name):
model_name = name
def get_graph_being_compiled() -> str:
def get_aot_compilation_context() -> Tuple[List[str], str, int]:
return list(graph_being_compiled), model_name, nth_graph
def get_aot_graph_name() -> str:
"""
Returns the name of the graph being compiled.
"""
global model_name, graph_being_compiled, nth_graph
return f"{model_name}_{graph_being_compiled}_{nth_graph}"
return f"{model_name}_{'_'.join(graph_being_compiled)}_{nth_graph}"
get_graph_being_compiled = get_aot_graph_name
@contextmanager
def track_graph_compiling(graph_name, increment_index=False):
global graph_being_compiled
graph_being_compiled = graph_name
graph_being_compiled = [graph_name]
yield
if increment_index:
global nth_graph
nth_graph += 1
graph_being_compiled = None
graph_being_compiled = []
def make_boxed_func(f):
@@ -264,7 +272,7 @@ class AOTConfig:
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
with track_graph_compiling("forward"):
with track_graph_compiling("inference"):
compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
@wraps(compiled_fw)

View file

@@ -10,6 +10,8 @@ from .._src.aot_autograd import (
clear_compile_cache,
aot_module_simplified,
get_graph_being_compiled,
get_aot_graph_name,
get_aot_compilation_context,
make_boxed_func,
make_boxed_compiler
)

View file

@@ -23,8 +23,10 @@ from functorch import (
from functorch._src.aot_autograd import aot_module_simplified
from functorch.compile import (
nnc_jit, compiled_function, compiled_module,
min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion, clear_compile_cache
min_cut_rematerialization_partition, aot_function, aot_module,
decomposition_table, nop,
num_of_recompilations, default_partition, default_decompositions,
memory_efficient_fusion, clear_compile_cache, get_aot_compilation_context
)
from torch.testing._internal.common_device_type import ops
@@ -330,6 +332,22 @@ class TestAOTAutograd(AOTTestCase):
inp = [torch.randn(5, requires_grad=True) for _ in range(3)]
f(*inp).sum().backward()
def test_compilation_context(self):
def f(x):
return x.sin().sin()
count = []
def compiler(fx_g, _):
context = get_aot_compilation_context()
count.append((context[0], len(fx_g.graph.nodes)))
return fx_g
f = aot_function(f, compiler)
out = f(torch.randn(5, requires_grad=True))
f(torch.randn(5))
out.sum().backward()
self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])
class TestEagerFusionOpInfo(AOTTestCase):