mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Added inference to context when only compiling forwards (#83783)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83783 Approved by: https://github.com/pyjhzwh, https://github.com/jansel
This commit is contained in:
parent
08c03c91d7
commit
f45cd00d7a
3 changed files with 36 additions and 8 deletions
|
|
@ -184,7 +184,8 @@ def _reshape_alias(x, shape, strides):
|
|||
return aten.view(x, shape)
|
||||
|
||||
|
||||
graph_being_compiled: str = None
|
||||
# This is a list since looking forward, we can have this arbitrarily nested.
|
||||
graph_being_compiled: List[str] = []
|
||||
nth_graph: int = 0
|
||||
model_name: str = "model"
|
||||
|
||||
|
|
@ -194,23 +195,30 @@ def set_model_name(name):
|
|||
model_name = name
|
||||
|
||||
|
||||
def get_graph_being_compiled() -> str:
|
||||
def get_aot_compilation_context() -> Tuple[List[str], str, int]:
|
||||
return list(graph_being_compiled), model_name, nth_graph
|
||||
|
||||
|
||||
def get_aot_graph_name() -> str:
|
||||
"""
|
||||
Returns the name of the graph being compiled.
|
||||
"""
|
||||
global model_name, graph_being_compiled, nth_graph
|
||||
return f"{model_name}_{graph_being_compiled}_{nth_graph}"
|
||||
return f"{model_name}_{'_'.join(graph_being_compiled)}_{nth_graph}"
|
||||
|
||||
|
||||
get_graph_being_compiled = get_aot_graph_name
|
||||
|
||||
|
||||
@contextmanager
|
||||
def track_graph_compiling(graph_name, increment_index=False):
|
||||
global graph_being_compiled
|
||||
graph_being_compiled = graph_name
|
||||
graph_being_compiled = [graph_name]
|
||||
yield
|
||||
if increment_index:
|
||||
global nth_graph
|
||||
nth_graph += 1
|
||||
graph_being_compiled = None
|
||||
graph_being_compiled = []
|
||||
|
||||
|
||||
def make_boxed_func(f):
|
||||
|
|
@ -264,7 +272,7 @@ class AOTConfig:
|
|||
|
||||
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
|
||||
fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
|
||||
with track_graph_compiling("forward"):
|
||||
with track_graph_compiling("inference"):
|
||||
compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
|
||||
|
||||
@wraps(compiled_fw)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from .._src.aot_autograd import (
|
|||
clear_compile_cache,
|
||||
aot_module_simplified,
|
||||
get_graph_being_compiled,
|
||||
get_aot_graph_name,
|
||||
get_aot_compilation_context,
|
||||
make_boxed_func,
|
||||
make_boxed_compiler
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,8 +23,10 @@ from functorch import (
|
|||
from functorch._src.aot_autograd import aot_module_simplified
|
||||
from functorch.compile import (
|
||||
nnc_jit, compiled_function, compiled_module,
|
||||
min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
|
||||
num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion, clear_compile_cache
|
||||
min_cut_rematerialization_partition, aot_function, aot_module,
|
||||
decomposition_table, nop,
|
||||
num_of_recompilations, default_partition, default_decompositions,
|
||||
memory_efficient_fusion, clear_compile_cache, get_aot_compilation_context
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_device_type import ops
|
||||
|
|
@ -330,6 +332,22 @@ class TestAOTAutograd(AOTTestCase):
|
|||
inp = [torch.randn(5, requires_grad=True) for _ in range(3)]
|
||||
f(*inp).sum().backward()
|
||||
|
||||
def test_compilation_context(self):
|
||||
def f(x):
|
||||
return x.sin().sin()
|
||||
count = []
|
||||
|
||||
def compiler(fx_g, _):
|
||||
context = get_aot_compilation_context()
|
||||
count.append((context[0], len(fx_g.graph.nodes)))
|
||||
return fx_g
|
||||
|
||||
f = aot_function(f, compiler)
|
||||
out = f(torch.randn(5, requires_grad=True))
|
||||
f(torch.randn(5))
|
||||
out.sum().backward()
|
||||
self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])
|
||||
|
||||
|
||||
|
||||
class TestEagerFusionOpInfo(AOTTestCase):
|
||||
|
|
|
|||
Loading…
Reference in a new issue